// naive BatchConvBiasForward implementation (opr_impl.cpp)
#include "src/naive/batch_conv_bias/opr_impl.h"
#include "megdnn/oprs/nn.h"
#include "src/common/conv_bias.h"
#include "src/naive/conv_bias/opr_impl.h"
#include "src/naive/convolution/helper.h"

#include <cstring>

#include "megdnn/algorithm_cache.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;
using namespace convolution;

namespace {
//! Filter-pointer visitor for batched convolution: unlike ordinary
//! convolution, every batch owns its own filter tensor, so the current
//! filter pointer advances by one whole filter (`filter_sizes` elements)
//! per batch. Output-channel/row/column indices are irrelevant here.
struct BatchConvFilterVisitor {
    template <typename ftype>
    static ftype* get_current_ptr(
            ftype* fptr, size_t batch, size_t /* oc */, size_t /* oh */,
            size_t /* ow */, size_t filter_sizes) {
        // filters are stored contiguously, one full filter per batch
        return fptr + batch * filter_sizes;
    }
};
}  // namespace

WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle(
        dt_byte* raw_ptr, const TensorLayout& /* src */, const TensorLayout& /* flt */,
        const TensorLayout& bias, const TensorLayout& z, const TensorLayout& dst) {
    size_t ws_bias_size = 0, ws_z_size = 0;
    // Slot 0: staging tensor for the conv+bias result in the bias dtype,
    // needed when it differs from dst's dtype (exec() writes into ws.get(0)).
    // NOTE: fixed — this previously assigned ws_z_size, leaving slot 0
    // zero-sized and letting the z-branch below overwrite the value.
    if (bias.dtype.enumv() != dst.dtype.enumv()) {
        ws_bias_size = TensorLayout{dst, bias.dtype}.span().dist_byte();
    }
    // Slot 1: float32 scratch for the residual path, sized for two tensors:
    // (w * f + b).astype(float) + (z).astype(float)
    if (z.ndim > 0) {
        megdnn_assert(z.dtype.enumv() == DTypeEnum::QuantizedS8);
        megdnn_assert(z.eq_shape(dst));
        size_t f32_z_size = TensorLayout{z, dtype::Float32()}.span().dist_byte();
        ws_z_size = f32_z_size + f32_z_size;
    }
    return WorkspaceBundle{raw_ptr, {ws_bias_size, ws_z_size}};
}

size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& flt, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    // Size query only: build the bundle over a null base pointer and
    // report its total footprint.
    return get_workspace_bundle(nullptr, src, flt, bias, z, dst).total_size_in_bytes();
}

M
Megvii Engine Team 已提交
50 51 52 53 54 55 56 57 58
void BatchConvBiasForwardImpl::exec(
        _megdnn_tensor_in src, _megdnn_tensor_in filter, _megdnn_tensor_in bias,
        _megdnn_tensor_in z, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
    auto filter_meta = check_exec(
            src.layout, filter.layout, bias.layout, z.layout, dst.layout,
            workspace.size);
    WorkspaceBundle ws = get_workspace_bundle(
            workspace.raw_ptr, src.layout, filter.layout, bias.layout, z.layout,
            dst.layout);
59 60 61 62 63
    auto sfb = dst;
    if (bias.layout.dtype.enumv() != dst.layout.dtype.enumv()) {
        sfb = TensorND{ws.get(0), TensorLayout{dst.layout, bias.layout.dtype}};
    }

M
Megvii Engine Team 已提交
64 65 66 67 68 69 70 71 72
#define DISPATCH_RAW(in_dt, bias_dt, out_dt, cmode, func)                     \
    else if (                                                                 \
            src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv &&    \
            filter.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \
            bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv && \
            sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv &&   \
            param().compute_mode == Param::ComputeMode::cmode) {              \
        MEGDNN_DISPATCH_CPU_KERN_OPR(                                         \
                func(src, filter, bias, sfb, nullptr, filter_meta));          \
73
    }
M
Megvii Engine Team 已提交
74 75 76 77 78 79 80 81 82
#define DISPATCH(in_dt, out_dt)                                                       \
    DISPATCH_RAW(                                                                     \
            in_dt, out_dt, out_dt, DEFAULT,                                           \
            (forward_bias<                                                            \
                    DTypeTrait<dtype::in_dt>::ctype, DTypeTrait<dtype::in_dt>::ctype, \
                    DTypeTrait<dtype::out_dt>::ctype,                                 \
                    DTypeTrait<dtype::out_dt>::ctype,                                 \
                    BatchConvBiasForward::CanonizedFilterMeta,                        \
                    BatchConvFilterVisitor>))
83 84 85 86 87 88 89 90 91 92 93 94
    if (0) {
    }
    DISPATCH(QuantizedS8, QuantizedS32)
    else {
        megdnn_throw(ssprintf(
                "unsupported naive BatchConvBias(%s, %s, %s, %s) -> %s",
                src.layout.dtype.name(), filter.layout.dtype.name(),
                bias.layout.dtype.name(), z.layout.dtype.name(),
                dst.layout.dtype.name()));
    }
#undef DISPATCH
#undef DISPATCH_RAW
95
    MEGDNN_DISPATCH_CPU_KERN_OPR(handle_z_inp_and_activation_naive(
M
Megvii Engine Team 已提交
96
            param().nonlineMode, sfb, z, dst, reinterpret_cast<dt_byte*>(ws.get(1))));
97 98
}

M
Megvii Engine Team 已提交
99 100 101 102 103
std::vector<BatchConvBiasForward::Algorithm*> BatchConvBiasForwardImpl::
        get_all_algorithms(
                const TensorLayout&, const TensorLayout&, const TensorLayout&,
                const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_batch_conv_bias_fwd_algo()};
104 105
}

M
Megvii Engine Team 已提交
106 107 108 109 110
std::vector<BatchConvBiasForward::Algorithm*> BatchConvBiasForwardImpl::
        get_all_algorithms_safe(
                const TensorLayout&, const TensorLayout&, const TensorLayout&,
                const TensorLayout&, const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())->default_batch_conv_bias_fwd_algo()};
111 112
}

M
Megvii Engine Team 已提交
113
BatchConvBiasForward::Algorithm* BatchConvBiasForwardImpl::get_algorithm_heuristic(
114 115 116 117
        const TensorLayout& /* src */, const TensorLayout& /* filter */,
        const TensorLayout& /* bias */, const TensorLayout& /* z */,
        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */
        ,
M
Megvii Engine Team 已提交
118 119
        const AlgoAttribute& positive_attr, const AlgoAttribute& negative_attr) {
    auto algo = static_cast<HandleImpl*>(handle())->default_batch_conv_bias_fwd_algo();
120
    algo->check_attribute(positive_attr, negative_attr);
121 122 123
    return algo;
}

M
Megvii Engine Team 已提交
124 125 126 127
BatchConvBiasForward::Algorithm* BatchConvBiasForwardImpl::get_algorithm_from_desc(
        const AlgorithmDesc& desc) {
    Algorithm* ret =
            static_cast<HandleImpl*>(handle())->default_batch_conv_bias_fwd_algo();
128 129 130 131
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

// vim: syntax=cpp.doxygen