/**
 * \file dnn/src/naive/elemwise_multi_type/opr_impl.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#pragma once

#include "megdnn/tensor_iter.h"

#include "src/common/elemwise_multi_type/opr_impl_helper.h"
#include "src/naive/handle.h"

namespace megdnn {
namespace naive {

class ElemwiseMultiTypeImpl : public ElemwiseMultiTypeImplHelper {
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40
    template <typename KernImpl, typename src_ctype, typename ElemParam>
    void dispatch_add_qint_op_dst(const ElemParam& param, const TensorND& dst) {
        switch (dst.layout.dtype.enumv()) {
#define cb(_dt)                                                                     \
    case DTypeTrait<_dt>::enumv:                                                    \
        dispatch_add_qint_op<KernImpl, src_ctype, typename DTypeTrait<_dt>::ctype>( \
                param, dst);                                                        \
        break;
            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
            MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb

            default:
                megdnn_assert(
                        0, "not support %s %s\n", param[0].layout.dtype.name(),
                        dst.layout.dtype.name());
        }
    }

41
    template <typename KernImpl, typename ElemParam>
42 43 44 45 46 47 48 49 50 51
    void dispatch_qint_op_dtype(const ElemParam& param, const TensorND& dst) {
        switch (param[0].layout.dtype.enumv()) {
#define cb(_dt)                                                                    \
    case DTypeTrait<_dt>::enumv:                                                   \
        dispatch_add_qint_op_dst<                                                  \
                KernImpl, typename DTypeTrait<_dt>::ctype, ElemParam>(param, dst); \
        break;
            MEGDNN_FOREACH_QUANTIZED_DTYPE(cb)
            MEGDNN_FOREACH_QUANTIZED_LOWBIT_DTYPE(cb)
#undef cb
52

53 54 55 56
            default:
                megdnn_assert_internal(0);
        }
    }
57 58

    template <typename KernImpl, typename src_ctype, typename dst_ctype>
M
Megvii Engine Team 已提交
59
    void dispatch_add_qint_op(
60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79
            const ElemwiseOpParamN<1>& param, const TensorND& dst_tensor) {
        auto src = param[0];
        auto size = param.size;
        auto work = [src, size, dst_tensor]() {
            auto iA = tensor_iter_valonly<src_ctype>(src).begin();
            auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();

            auto param0 =
                    src.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
            auto dst_param = dst_tensor.layout.dtype
                                     .param<typename DTypeTrait<dst_ctype>::dtype>();
            for (size_t i = 0; i < size; i++) {
                src_ctype a = *iA;
                *pD = dst_param.quantize(KernImpl::apply(param0.dequantize(a)));
                ++iA;
                ++pD;
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
    }
80 81

    template <typename KernImpl, typename src_ctype, typename dst_ctype>
M
Megvii Engine Team 已提交
82
    void dispatch_add_qint_op(
83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109
            const ElemwiseOpParamN<2>& param, const TensorND& dst_tensor) {
        auto size = param.size;
        auto src0 = param[0];
        auto src1 = param[1];
        auto work = [src0, src1, size, dst_tensor]() {
            // This is needed as these iterators are captured as const value.
            auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
            auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
            auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
            auto param0 =
                    src0.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
            auto param1 =
                    src1.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
            auto dst_param = dst_tensor.layout.dtype
                                     .param<typename DTypeTrait<dst_ctype>::dtype>();
            for (size_t i = 0; i < size; i++) {
                src_ctype a = *iA;
                src_ctype b = *iB;
                *pD = dst_param.quantize(
                        KernImpl::apply(param0.dequantize(a), param1.dequantize(b)));
                ++iA;
                ++iB;
                ++pD;
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
    }
110 111

    template <typename KernImpl, typename src_ctype, typename dst_ctype>
M
Megvii Engine Team 已提交
112
    void dispatch_add_qint_op(
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146
            const ElemwiseOpParamN<3>& param, const TensorND& dst_tensor) {
        auto size = param.size;
        auto src0 = param[0];
        auto src1 = param[1];
        auto src2 = param[2];
        auto work = [src0, src1, src2, size, dst_tensor]() {
            // This is needed as these iterators are captured as const value.
            auto iA = tensor_iter_valonly<src_ctype>(src0).begin();
            auto iB = tensor_iter_valonly<src_ctype>(src1).begin();
            auto iC = tensor_iter_valonly<src_ctype>(src2).begin();
            auto pD = tensor_iter_valonly<dst_ctype>(dst_tensor).begin();
            auto param0 =
                    src0.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
            auto param1 =
                    src1.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
            auto param2 =
                    src2.layout.dtype.param<typename DTypeTrait<src_ctype>::dtype>();
            auto dst_param = dst_tensor.layout.dtype
                                     .param<typename DTypeTrait<dst_ctype>::dtype>();
            for (size_t i = 0; i < size; i++) {
                src_ctype a = *iA;
                src_ctype b = *iB;
                src_ctype c = *iC;
                *pD = dst_param.quantize(KernImpl::apply(
                        param0.dequantize(a), param1.dequantize(b),
                        param2.dequantize(c)));
                ++iA;
                ++iB;
                ++iC;
                ++pD;
            }
        };
        MEGDNN_DISPATCH_CPU_KERN_OPR(work());
    }
147 148 149

protected:
    template <typename ctype>
150 151
    void dispatch_fma3_iXxf32xf32xi8(
            const ElemwiseOpParamN<3>& param, const TensorND& dst);
152 153

    template <typename ctype, typename dst_ctype>
M
Megvii Engine Team 已提交
154
    void dispatch_round_shr_saturate_iXxi8xiX(
155
            const ElemwiseOpParamN<2>& param, const TensorND& dst);
156 157 158

    template <typename ctype>
    void dispatch_fuse_add_rmulh_round_shr_saturate(
159
            const ElemwiseOpParamN<6>& param, const TensorND& dst);
160

M
Megvii Engine Team 已提交
161
    void on_fuse_mul_add3_int16x32x32x32(
162
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
M
Megvii Engine Team 已提交
163
    void on_fuse_mul_add3_iXxf32xf32xi8(
164
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
M
Megvii Engine Team 已提交
165
    void on_round_shr_saturate_iXxi8xi8(
166
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
167
    void on_fuse_add_rmulh_round_shr_saturate_int16x16x16x8(
168
            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
169
    void on_fuse_add_rmulh_round_shr_saturate_int32x32x32x8(
170
            const ElemwiseOpParamN<6>& param, const TensorND& dst) override;
M
Megvii Engine Team 已提交
171
    void on_round_shr_saturate_iXxi8xi16(
172
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
173 174 175 176 177 178
    void on_fuse_mul_add3_int16xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
    void on_mul_int16xf32xf32(
            const ElemwiseOpParamN<2>& param, const TensorND& dst) override;
    void on_fuse_mul_add3_uint8xf32xf32xf32(
            const ElemwiseOpParamN<3>& param, const TensorND& dst) override;
179

M
Megvii Engine Team 已提交
180 181 182
    void on_quantized_mode(
            const ElemwiseOpParamN<1>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
183

M
Megvii Engine Team 已提交
184 185 186
    void on_quantized_mode(
            const ElemwiseOpParamN<2>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
187

M
Megvii Engine Team 已提交
188 189 190
    void on_quantized_mode(
            const ElemwiseOpParamN<3>& param, const TensorND& dst,
            Elemwise::Mode mode) override;
191 192 193 194 195 196 197 198 199

public:
    using ElemwiseMultiTypeImplHelper::ElemwiseMultiTypeImplHelper;
};

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen