#include "./opr_impl.h"
#include "./helper.h"

#include "megdnn/algorithm_cache.h"
#include "megdnn/dtype.h"
#include "megdnn/tensor_iter.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

#include <cstring>

#include "midout.h"
MIDOUT_DECL(megdnn_naive_conv_fwd)

using namespace megdnn;
using namespace naive;

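// Forward entry point: validate layouts, then walk the dtype/compute-mode
// dispatch table below and run the first matching naive kernel.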
void ConvolutionForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_out dst,
        const PreprocessedFilter* preprocessed_filter, _megdnn_workspace workspace) {
    MIDOUT_BEGIN(megdnn_naive_conv_fwd) {
        auto filter_meta = check_exec(
                src.layout, filter.layout, dst.layout, workspace.size,
                preprocessed_filter);
        using ComputeMode = Param::ComputeMode;
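// Guarded dispatch: when the src/dst dtypes and compute mode all match, run
// the templated naive kernel on the CPU dispatcher and return from exec().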
#define DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, cmode)      \
    do {                                                                  \
        using namespace dtype;                                            \
        if (src.layout.dtype.enumv() == DTypeTrait<in_dt>::enumv &&       \
            dst.layout.dtype.enumv() == DTypeTrait<out_dt>::enumv &&      \
            param().compute_mode == cmode) {                              \
            MEGDNN_DISPATCH_CPU_KERN_OPR(                                 \
                    (convolution::forward<in_ct, in_ct, out_ct, comp_ct>( \
                            src, filter, dst, filter_meta)););            \
            return;                                                       \
        }                                                                 \
    } while (0);
#define DISPATCH(in_dt, out_dt, in_ct, out_ct, comp_ct) \
    DISPATCH_CMODE(in_dt, out_dt, in_ct, out_ct, comp_ct, ComputeMode::DEFAULT)
#define cb(dt)                                                    \
    DISPATCH(                                                     \
            dt, dt, DTypeTrait<dt>::ctype, DTypeTrait<dt>::ctype, \
            DTypeTrait<dt>::ctype)
        MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
        DISPATCH(Int8, Int16, dt_int8, dt_int16, dt_int16);
        DISPATCH(Int8, Int32, dt_int8, dt_int32, dt_int32);
        DISPATCH(QuantizedS8, QuantizedS32, dt_int8, dt_int32, dt_int32);
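        // fp16/bf16 inputs may accumulate in fp32 (ComputeMode::FLOAT32);
        // these entries are compiled out when fp16 support is disabled.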
        DNN_INC_FLOAT16(DISPATCH_CMODE(
                Float16, Float16, dt_float16, dt_float16, dt_float32,
                ComputeMode::FLOAT32));
        DNN_INC_FLOAT16(DISPATCH_CMODE(
                BFloat16, BFloat16, dt_bfloat16, dt_bfloat16, dt_float32,
                ComputeMode::FLOAT32));
        DISPATCH(Quantized8Asymm, QuantizedS32, dt_quint8, dt_qint32, dt_qint32);
        DISPATCH(QuantizedS8, QuantizedS8, dt_int8, dt_int8, dt_int32);
#undef DISPATCH
        megdnn_throw(ssprintf(
                "unsupported Conv(%s, %s) -> %s with cmode = %d",
                src.layout.dtype.name(), filter.layout.dtype.name(),
                dst.layout.dtype.name(), static_cast<int>(param().compute_mode)));
    }
    MIDOUT_END();
}

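// Workspace is needed only when the kernel cannot write grad in place:
// fp16/bf16 grads are staged in fp32 and s8 grads in QuantizedS32, then
// converted back in exec().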
size_t ConvolutionBackwardDataImpl::get_workspace_in_bytes(
        const TensorLayout& filter, const TensorLayout& diff,
        const TensorLayout& grad) {
    size_t workspace_size = 0;
    auto flt_dt = filter.dtype.enumv();
    auto grad_dt = grad.dtype.enumv();
    auto diff_dt = diff.dtype.enumv();
#if !MEGDNN_DISABLE_FLOAT16
    if (flt_dt == DTypeEnum::Float16 || flt_dt == DTypeEnum::BFloat16) {
        megdnn_assert(flt_dt == grad_dt && flt_dt == diff_dt);
        workspace_size = grad.span().dist_elem() * dtype::Float32().size();
    }
#endif
    if ((flt_dt == DTypeEnum::Int8 || flt_dt == DTypeEnum::QuantizedS8) &&
        (diff_dt == DTypeEnum::Int8 || diff_dt == DTypeEnum::QuantizedS8) &&
        (grad_dt == DTypeEnum::Int8 || grad_dt == DTypeEnum::QuantizedS8)) {
        workspace_size = TensorLayout{grad, dtype::QuantizedS32()}.span().dist_byte();
    }

    return workspace_size;
}

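// Backward data: float dtypes dispatch directly; the mixed-precision and
// quantized paths stage grad in the workspace (see get_workspace_in_bytes).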
void ConvolutionBackwardDataImpl::exec(
        _megdnn_tensor_in filter, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    auto filter_meta =
            check_exec(filter.layout, diff.layout, grad.layout, workspace.size);
    using ComputeMode = Param::ComputeMode;
    auto cmode = param().compute_mode;
#define cb(dt)                                                              \
    do {                                                                    \
        if (filter.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) { \
            using ctype = DTypeTrait<dt>::ctype;                            \
            MEGDNN_DISPATCH_CPU_KERN_OPR(                                   \
                    (convolution::backward_data<ctype, ctype, ctype>(       \
                            filter, diff, grad, filter_meta)););            \
            return;                                                         \
        }                                                                   \
    } while (0);
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
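    // Mixed precision: convert grad up to an fp32 workspace tensor, run the
    // kernel with fp32 accumulation, then convert the result back to fp16.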
    if (filter.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32) {
        TensorND grad_fp32{
                workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
        auto&& type_cvt = handle()->create_operator<TypeCvt>();
        type_cvt->exec(grad, grad_fp32);
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_data<dt_float16, dt_float16, dt_float32>(
                        filter, diff, grad_fp32, filter_meta)););
        type_cvt->exec(grad_fp32, grad);
        return;
    }
    if (filter.layout.dtype == dtype::BFloat16() && cmode == ComputeMode::FLOAT32) {
        TensorND grad_fp32{
                workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
        auto&& type_cvt = handle()->create_operator<TypeCvt>();
        type_cvt->exec(grad, grad_fp32);
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_data<dt_bfloat16, dt_bfloat16, dt_float32>(
                        filter, diff, grad_fp32, filter_meta)););
        type_cvt->exec(grad_fp32, grad);
        return;
    }
#endif
    auto flt_dt = filter.layout.dtype.enumv();
    auto grad_dt = grad.layout.dtype.enumv();
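    // s8 -> s8: accumulate into a qint32 workspace tensor whose scale is
    // filter_scale * diff_scale, then quantize back into grad via TypeCvt.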
    if ((flt_dt == DTypeEnum::Int8 || flt_dt == DTypeEnum::QuantizedS8) &&
        (grad_dt == DTypeEnum::Int8 || grad_dt == DTypeEnum::QuantizedS8)) {
        auto resf_s = filter.layout.dtype.param<dtype::QuantizedS8>().scale *
                      diff.layout.dtype.param<dtype::QuantizedS8>().scale;
        TensorND res{
                workspace.raw_ptr,
                TensorLayout{grad.layout, dtype::QuantizedS32(resf_s)}};
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_data<dt_qint8, dt_qint8, dt_qint32>(
                        filter, diff, res, filter_meta)););
        handle()->create_operator<TypeCvt>()->exec(res, grad);

        return;
    }
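    // s8 -> s32: the kernel writes 32-bit results directly; no staging needed.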
    if ((flt_dt == DTypeEnum::Int8 || flt_dt == DTypeEnum::QuantizedS8) &&
        (grad_dt == DTypeEnum::Int32 || grad_dt == DTypeEnum::QuantizedS32)) {
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_data<dt_int8, dt_int8, dt_int32>(
                        filter, diff, grad, filter_meta)););
        return;
    }
    if (flt_dt == DTypeEnum::Quantized8Asymm && grad_dt == DTypeEnum::QuantizedS32) {
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_data<dt_quint8, dt_quint8, dt_qint32>(
                        filter, diff, grad, filter_meta)););
        return;
    }
    megdnn_throw(ssprintf(
            "unsupported ConvolutionBackwardData(%s, %s) -> %s with cmode = %d",
            filter.layout.dtype.name(), diff.layout.dtype.name(),
            grad.layout.dtype.name(), static_cast<int>(cmode)));
}

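// Backward filter needs workspace only for the fp16/bf16 mixed-precision
// path: an fp32 staging copy of the filter gradient.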
size_t ConvolutionBackwardFilterImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& diff, const TensorLayout& grad) {
    size_t workspace_size = 0;
#if !MEGDNN_DISABLE_FLOAT16
    auto src_dt = src.dtype.enumv();
    auto grad_dt = grad.dtype.enumv();
    auto diff_dt = diff.dtype.enumv();
    if (src_dt == DTypeEnum::Float16 || src_dt == DTypeEnum::BFloat16) {
        megdnn_assert(src_dt == grad_dt && src_dt == diff_dt);
        workspace_size = grad.span().dist_elem() * dtype::Float32().size();
    }
#endif

    return workspace_size;
}

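// Backward filter: same structure as backward data, staging the filter
// gradient in fp32 for the mixed-precision paths.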
void ConvolutionBackwardFilterImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in diff, _megdnn_tensor_out grad,
        _megdnn_workspace workspace) {
    auto filter_meta = check_exec(src.layout, diff.layout, grad.layout, workspace.size);
    using ComputeMode = Param::ComputeMode;
    auto cmode = param().compute_mode;
#define cb(dt)                                                            \
    do {                                                                  \
        if (src.layout.dtype == dt() && cmode == ComputeMode::DEFAULT) {  \
            using ctype = DTypeTrait<dt>::ctype;                          \
            MEGDNN_DISPATCH_CPU_KERN(                                     \
                    static_cast<HandleImpl*>(handle()),                   \
                    convolution::backward_filter<                         \
                            ctype MEGDNN_COMMA ctype MEGDNN_COMMA ctype>( \
                            src, diff, grad, filter_meta););              \
            return;                                                       \
        }                                                                 \
    } while (0);
    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb);
#undef cb
#if !MEGDNN_DISABLE_FLOAT16
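    // Mixed precision: stage the filter gradient in fp32, as in backward data.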
    if (src.layout.dtype == dtype::Float16() && cmode == ComputeMode::FLOAT32) {
        TensorND grad_fp32{
                workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
        auto&& type_cvt = handle()->create_operator<TypeCvt>();
        type_cvt->exec(grad, grad_fp32);
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_filter<dt_float16, dt_float16, dt_float32>(
                        src, diff, grad_fp32, filter_meta)););
        type_cvt->exec(grad_fp32, grad);
        return;
    }
    if (src.layout.dtype == dtype::BFloat16() && cmode == ComputeMode::FLOAT32) {
        TensorND grad_fp32{
                workspace.raw_ptr, TensorLayout{grad.layout, dtype::Float32()}};
        auto&& type_cvt = handle()->create_operator<TypeCvt>();
        type_cvt->exec(grad, grad_fp32);
        MEGDNN_DISPATCH_CPU_KERN_OPR(
                (convolution::backward_filter<dt_bfloat16, dt_bfloat16, dt_float32>(
                        src, diff, grad_fp32, filter_meta)););
        type_cvt->exec(grad_fp32, grad);
        return;
    }

#endif

    megdnn_assert_internal(0);
}

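// The naive backend exposes a single default algorithm per opr; every query
// below (all / safe / heuristic / from_desc) resolves to it.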
std::vector<ConvolutionForward::Algorithm*> ConvolutionForwardImpl::get_all_algorithms(
        const TensorLayout&, const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_conv_fwd_algo()};
}

std::vector<ConvolutionForward::Algorithm*> ConvolutionForwardImpl::
        get_all_algorithms_safe(
                const TensorLayout&, const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_conv_fwd_algo()};
}

ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* filter */,
        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto algo = static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
    algo->check_attribute(positive_attr, negative_attr);
    return algo;
}

ConvolutionForward::Algorithm* ConvolutionForwardImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    Algorithm* ret = static_cast<HandleImpl*>(handle())->default_conv_fwd_algo();
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

std::vector<ConvolutionBackwardData::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms(
                const TensorLayout&, const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo()};
}

std::vector<ConvolutionBackwardData::Algorithm*> ConvolutionBackwardDataImpl::
        get_all_algorithms_safe(
                const TensorLayout&, const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo()};
}

ConvolutionBackwardData::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_heuristic(
                const TensorLayout& /* filter */, const TensorLayout& /* diff */,
                const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    auto algo = static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
    algo->check_attribute(positive_attr, negative_attr);
    return algo;
}

ConvolutionBackwardData::Algorithm* ConvolutionBackwardDataImpl::
        get_algorithm_from_desc(const AlgorithmDesc& desc) {
    Algorithm* ret = static_cast<HandleImpl*>(handle())->default_conv_bwd_data_algo();
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

std::vector<ConvolutionBackwardFilter::Algorithm*> ConvolutionBackwardFilterImpl::
        get_all_algorithms(
                const TensorLayout&, const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()};
}

std::vector<ConvolutionBackwardFilter::Algorithm*> ConvolutionBackwardFilterImpl::
        get_all_algorithms_safe(
                const TensorLayout&, const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo()};
}

ConvolutionBackwardFilter::Algorithm* ConvolutionBackwardFilterImpl::
        get_algorithm_heuristic(
                const TensorLayout& /* src */, const TensorLayout& /* diff */,
                const TensorLayout& /* grad */, size_t /* workspace_limit_in_bytes */,
                const AlgoAttribute& positive_attr,
                const AlgoAttribute& negative_attr) {
    auto algo = static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
    algo->check_attribute(positive_attr, negative_attr);
    return algo;
}

ConvolutionBackwardFilter::Algorithm* ConvolutionBackwardFilterImpl::
        get_algorithm_from_desc(const AlgorithmDesc& desc) {
    Algorithm* ret = static_cast<HandleImpl*>(handle())->default_conv_bwd_filter_algo();
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

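// All three oprs report one fixed algorithm-set name.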
const char* ConvolutionForwardImpl::get_algorithm_set_name() const {
    return "DEFAULT";
}

const char* ConvolutionBackwardDataImpl::get_algorithm_set_name() const {
    return "DEFAULT";
}

const char* ConvolutionBackwardFilterImpl::get_algorithm_set_name() const {
    return "DEFAULT";
}

// vim: syntax=cpp.doxygen