// opr_impl.cpp: naive (reference) ConvBias forward implementation.
#include "src/naive/conv_bias/opr_impl.h"
#include "src/naive/convolution/helper.h"

#include <cstring>
#include "megdnn/algorithm_cache.h"
#include "megdnn/dtype.h"
#include "src/common/conv_bias.h"
#include "src/common/opr_delegate.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include "src/naive/lowbit_utils.h"

#include "midout.h"
MIDOUT_DECL(megdnn_naive_conv_bias_fwd)

namespace megdnn {
namespace naive {

19 20 21
//! Only used for naive implementation. DO NOT use the following function in
//! other backends.
void handle_z_inp_and_activation_naive(
M
Megvii Engine Team 已提交
22 23
        param::ConvBias::NonlineMode nonline_mode, const TensorND& conv_bias_tensor,
        const TensorND& z_tensor, const TensorND& dst_tensor, dt_byte* workspace_ptr) {
24
    auto res = dst_tensor, z_float = z_tensor;
25
    //! create naive inplace handle
26 27 28
    auto handle = inplace_cpu_handle(2);
    if (z_tensor.layout.ndim > 0 &&
        z_tensor.layout.dtype.category() != DTypeCategory::FLOAT) {
M
Megvii Engine Team 已提交
29
        dt_byte *res_float_workspace_ptr = nullptr, *z_float_workspace_ptr = nullptr;
30 31
        megdnn_assert(z_tensor.layout.eq_shape(dst_tensor.layout));
        res_float_workspace_ptr = workspace_ptr;
M
Megvii Engine Team 已提交
32 33 34 35 36 37 38 39
        z_float_workspace_ptr =
                res_float_workspace_ptr +
                TensorLayout{z_tensor.layout, dtype::Float32()}.span().dist_byte();
        res = TensorND{
                res_float_workspace_ptr,
                TensorLayout{dst_tensor.layout, dtype::Float32()}};
        z_float = TensorND{
                z_float_workspace_ptr, TensorLayout{z_tensor.layout, dtype::Float32()}};
40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57
    }
    // ====================sfb + z_tensor=====================
    if (z_tensor.layout.ndim > 0) {
        if (z_tensor.layout.dtype.category() != DTypeCategory::FLOAT) {
            auto&& type_cvt = handle->create_operator<TypeCvt>();
            type_cvt->exec(conv_bias_tensor, res);
            type_cvt->exec(z_tensor, z_float);
        }
        auto add_opr = handle->create_operator<ElemwiseForward>();
        add_opr->param().mode = Elemwise::Param::Mode::ADD;
        add_opr->exec({res, z_float}, res);
    } else {
        res = conv_bias_tensor;
    }

    using NonlineMode = param::ConvBias::NonlineMode;

    switch (nonline_mode) {
M
Megvii Engine Team 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
#define cb(_mode)                                                               \
    case NonlineMode::_mode: {                                                  \
        if (res.layout.dtype.category() != DTypeCategory::QUANTIZED) {          \
            auto nonlinear = handle->create_operator<ElemwiseForward>();        \
            nonlinear->param().mode = Elemwise::Param::Mode::_mode;             \
            if (res.layout.dtype == dst_tensor.layout.dtype) {                  \
                nonlinear->exec({res}, dst_tensor);                             \
            } else {                                                            \
                nonlinear->exec({res}, res);                                    \
                handle->create_operator<TypeCvt>()->exec(res, dst_tensor);      \
            }                                                                   \
        } else {                                                                \
            auto nonlinear = handle->create_operator<ElemwiseMultiType>();      \
            nonlinear->param().mode = ElemwiseMultiType::Param::Mode::Q##_mode; \
            nonlinear->exec({res}, dst_tensor);                                 \
        }                                                                       \
        break;                                                                  \
75 76 77 78 79
    }
        cb(RELU);
        cb(H_SWISH);
#undef cb
        case NonlineMode::SIGMOID: {
M
Megvii Engine Team 已提交
80
            megdnn_assert(res.layout.dtype.category() != DTypeCategory::QUANTIZED);
81 82 83
            auto nonlinear = handle->create_operator<ElemwiseForward>();
            nonlinear->param().mode = Elemwise::Param::Mode::SIGMOID;
            nonlinear->exec({res}, res);
84
            if (res.raw_ptr() != dst_tensor.raw_ptr()) {
85 86 87 88 89
                handle->create_operator<TypeCvt>()->exec(res, dst_tensor);
            }
            break;
        }
        case NonlineMode::IDENTITY: {
90
            if (res.raw_ptr() != dst_tensor.raw_ptr()) {
91 92 93 94 95 96 97 98 99
                handle->create_operator<TypeCvt>()->exec(res, dst_tensor);
            }
            break;
        }
        default:
            megdnn_assert(false);
    }
}

100 101 102 103 104 105 106 107 108 109 110
namespace convolution {

template <>
void forward_bias<dt_quint4, dt_quint4, dt_qint32, dt_qint32>(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, dt_byte* workspace_ptr,
        const ConvBiasForward::CanonizedFilterMeta& filter_meta) {
    auto convert_layout = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::Quantized4Asymm>();
        ret.dtype = dtype::Quantized8Asymm(param.scale, param.zero_point);
111
        ret.format = TensorFormat(ret.dtype);
112
        ret.init_contiguous_stride();
113 114 115
        return ret;
    };
    TensorND new_src = {workspace_ptr, convert_layout(src.layout)};
M
Megvii Engine Team 已提交
116 117 118
    TensorND new_flt = {
            workspace_ptr + new_src.layout.span().dist_byte(),
            convert_layout(filter.layout)};
119 120 121 122 123 124 125 126

    uint4_to_uint8(src, new_src);
    uint4_to_uint8(filter, new_flt);
    auto new_filter_meta = filter_meta;
    new_filter_meta.dtype = new_flt.layout.dtype;
    forward_bias<dt_quint8, dt_quint8, dt_qint32, dt_qint32>(
            new_src, new_flt, bias, dst, nullptr, new_filter_meta);
}
127 128 129 130 131 132 133 134 135 136 137

template <>
void forward_bias<dt_qint4, dt_qint4, dt_qint32, dt_qint32>(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, dt_byte* workspace_ptr,
        const ConvBiasForward::CanonizedFilterMeta& filter_meta) {
    auto convert_layout = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::QuantizedS4>();
        ret.dtype = dtype::QuantizedS8(param.scale);
        ret.format = TensorFormat(ret.dtype);
138
        ret.init_contiguous_stride();
139 140 141
        return ret;
    };
    TensorND new_src = {workspace_ptr, convert_layout(src.layout)};
M
Megvii Engine Team 已提交
142 143 144
    TensorND new_flt = {
            workspace_ptr + new_src.layout.span().dist_byte(),
            convert_layout(filter.layout)};
145 146 147 148 149 150 151
    int4_to_int8(src, new_src);
    int4_to_int8(filter, new_flt);
    auto new_filter_meta = filter_meta;
    new_filter_meta.dtype = new_flt.layout.dtype;
    forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>(
            new_src, new_flt, bias, dst, nullptr, new_filter_meta);
}
152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174

template <>
void forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_out dst, dt_byte* workspace_ptr,
        const ConvBiasForward::CanonizedFilterMeta& filter_meta) {
    auto convert_layout_src = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::Quantized4Asymm>();
        ret.dtype = dtype::QuantizedS8(param.scale);
        ret.format = TensorFormat(ret.dtype);
        ret.init_contiguous_stride();
        return ret;
    };
    auto convert_layout_flt = [](const TensorLayout& layout) {
        auto ret = layout;
        auto param = layout.dtype.param<dtype::QuantizedS4>();
        ret.dtype = dtype::QuantizedS8(param.scale);
        ret.format = TensorFormat(ret.dtype);
        ret.init_contiguous_stride();
        return ret;
    };
    TensorND new_src = {workspace_ptr, convert_layout_src(src.layout)};
M
Megvii Engine Team 已提交
175 176 177
    TensorND new_flt = {
            workspace_ptr + new_src.layout.span().dist_byte(),
            convert_layout_flt(filter.layout)};
178 179 180 181 182 183 184
    uint4_to_int8(src, new_src);
    int4_to_int8(filter, new_flt);
    auto new_filter_meta = filter_meta;
    new_filter_meta.dtype = new_flt.layout.dtype;
    forward_bias<dt_qint8, dt_qint8, dt_qint32, dt_qint32>(
            new_src, new_flt, bias, dst, nullptr, new_filter_meta);
}
185 186
}  // namespace convolution

M
Megvii Engine Team 已提交
187 188 189
size_t ConvBiasForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst, const PreprocessedFilter*) {
190 191 192 193 194
    size_t float_workspace_size = 0;

    if (z.ndim > 0 && z.dtype.category() != DTypeCategory::FLOAT) {
        megdnn_assert(z.eq_shape(dst));
        // (w * f + b).astype(float) + (z).astype(float)
M
Megvii Engine Team 已提交
195
        float_workspace_size = 2 * TensorLayout{z, dtype::Float32()}.span().dist_byte();
196
    }
M
Megvii Engine Team 已提交
197

198 199 200 201 202 203
    if ((src.dtype.enumv() == DTypeEnum::Quantized4Asymm ||
         src.dtype.enumv() == DTypeEnum::QuantizedS4) &&
        bias.dtype.enumv() == DTypeEnum::QuantizedS32) {
        float_workspace_size +=
                (src.total_nr_elems() + flt.total_nr_elems()) * sizeof(uint8_t);
    }
204 205

    if (bias.dtype.enumv() != dst.dtype.enumv()) {
M
Megvii Engine Team 已提交
206
        float_workspace_size += TensorLayout{dst, bias.dtype}.span().dist_byte();
207 208 209 210
    }
    return float_workspace_size;
}

M
Megvii Engine Team 已提交
211 212 213 214
void ConvBiasForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_in z, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
215
    MIDOUT_BEGIN(megdnn_naive_conv_bias_fwd) {
216
        dt_byte* workspace_ptr = workspace.raw_ptr;
217 218
        // ============================w * f + b================================

219 220 221
        auto filter_meta = check_exec_allow_noncontiguous(
                src.layout, filter.layout, bias.layout, z.layout, dst.layout,
                workspace.size, preprocessed_filter);
222 223 224
        auto sfb = dst;
        if (bias.layout.dtype.enumv() != dst.layout.dtype.enumv()) {
            // intermediate result
M
Megvii Engine Team 已提交
225
            sfb = TensorND{workspace_ptr, TensorLayout{dst.layout, bias.layout.dtype}};
226 227
            workspace_ptr += sfb.layout.span().dist_byte();
        }
228
#define DISPATCH_RAW(in_dt, flt_dt, bias_dt, out_dt, cmode, func)              \
M
Megvii Engine Team 已提交
229 230 231 232 233 234
    else if (                                                                  \
            src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv &&     \
            filter.layout.dtype.enumv() == DTypeTrait<dtype::flt_dt>::enumv && \
            bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv &&  \
            sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv &&    \
            param().compute_mode == Param::ComputeMode::cmode) {               \
235 236 237
        MEGDNN_DISPATCH_CPU_KERN_OPR(                                          \
                func(src, filter, bias, sfb, workspace_ptr, filter_meta));     \
    }
M
Megvii Engine Team 已提交
238 239 240 241 242 243 244
#define DISPATCH(in_dt, out_dt)                                                       \
    DISPATCH_RAW(                                                                     \
            in_dt, in_dt, out_dt, out_dt, DEFAULT,                                    \
            (convolution::forward_bias<                                               \
                    DTypeTrait<dtype::in_dt>::ctype, DTypeTrait<dtype::in_dt>::ctype, \
                    DTypeTrait<dtype::out_dt>::ctype,                                 \
                    DTypeTrait<dtype::out_dt>::ctype>))
245 246
        if (0) {
        }
247 248 249 250
        DISPATCH(Float32, Float32)
        DISPATCH(Int8, Int16)
        DISPATCH(Int8, Int32)
        DISPATCH(QuantizedS8, QuantizedS32)
251
        DISPATCH(QuantizedS8, Float32)
252 253
        DISPATCH(Quantized8Asymm, QuantizedS32)
        DISPATCH(Quantized4Asymm, QuantizedS32)
M
Megvii Engine Team 已提交
254 255 256
        DISPATCH_RAW(
                QuantizedS8, QuantizedS8, QuantizedS32, QuantizedS32, FLOAT32,
                (convolution::forward_bias<dt_int8, dt_int8, dt_int32, dt_int32>))
257
        DISPATCH(QuantizedS4, QuantizedS32)
M
Megvii Engine Team 已提交
258 259 260
        DISPATCH_RAW(
                Quantized4Asymm, QuantizedS4, QuantizedS32, QuantizedS32, DEFAULT,
                (convolution::forward_bias<dt_quint4, dt_qint4, dt_qint32, dt_qint32>))
261 262 263
        DISPATCH_RAW(
                QuantizedS1, QuantizedS1, QuantizedS32, QuantizedS32, FLOAT32,
                (convolution::forward_bias<dt_qint1, dt_qint1, dt_qint32, dt_qint32>))
264 265
#if !MEGDNN_DISABLE_FLOAT16
        DISPATCH(Float16, Float16)
M
Megvii Engine Team 已提交
266 267 268 269 270 271 272 273
        DISPATCH_RAW(
                Float16, Float16, Float16, Float16, FLOAT32,
                (convolution::forward_bias<
                        dt_float16, dt_float16, dt_float16, dt_float32>))
        DISPATCH_RAW(
                BFloat16, BFloat16, BFloat16, BFloat16, FLOAT32,
                (convolution::forward_bias<
                        dt_bfloat16, dt_bfloat16, dt_bfloat16, dt_float32>))
274 275 276 277 278 279 280 281 282 283
#endif
        else {
            megdnn_throw(ssprintf(
                    "unsupported naive ConvBias(%s, %s, %s, %s) -> %s",
                    src.layout.dtype.name(), filter.layout.dtype.name(),
                    bias.layout.dtype.name(), z.layout.dtype.name(),
                    dst.layout.dtype.name()));
        }
#undef DISPATCH
#undef DISPATCH_RAW
284 285
        MEGDNN_DISPATCH_CPU_KERN_OPR(handle_z_inp_and_activation_naive(
                param().nonlineMode, sfb, z, dst, workspace_ptr));
286 287 288 289
    }
    MIDOUT_END();
}

M
Megvii Engine Team 已提交
290 291 292
std::vector<ConvBiasForward::Algorithm*> ConvBiasForwardImpl::get_all_algorithms(
        const TensorLayout&, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, const TensorLayout&) {
293
    return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()};
294 295
}

M
Megvii Engine Team 已提交
296 297 298
std::vector<ConvBiasForward::Algorithm*> ConvBiasForwardImpl::get_all_algorithms_safe(
        const TensorLayout&, const TensorLayout&, const TensorLayout&,
        const TensorLayout&, const TensorLayout&) {
299
    return {static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo()};
300 301 302 303 304 305
}

//! Returns the handle's default ConvBias forward algorithm (the only one the
//! naive backend has), after passing it through check_attribute with the
//! requested positive / negative attributes. Layouts and the workspace limit
//! are ignored.
ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* filter */,
        const TensorLayout& /* bias */, const TensorLayout& /* z */,
        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto* naive_handle = static_cast<HandleImpl*>(handle());
    auto* chosen = naive_handle->default_conv_bias_fwd_algo();
    chosen->check_attribute(positive_attr, negative_attr);
    return chosen;
}

312
ConvBiasForward::Algorithm* ConvBiasForwardImpl::get_algorithm_from_desc(
313
        const AlgorithmDesc& desc) {
M
Megvii Engine Team 已提交
314
    Algorithm* ret = static_cast<HandleImpl*>(handle())->default_conv_bias_fwd_algo();
315 316 317 318
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

319 320 321 322 323 324 325 326
const char* ConvBiasForwardImpl::get_algorithm_set_name() const {
    return "DEFAULT";
}

}  // namespace naive
}  // namespace megdnn

// vim: syntax=cpp.doxygen