/**
 * \file dnn/src/naive/batch_conv_bias/opr_impl.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */
#include "src/naive/batch_conv_bias/opr_impl.h"
#include "megdnn/oprs/nn.h"
#include "src/common/conv_bias.h"
#include "src/naive/conv_bias/opr_impl.h"
#include "src/naive/convolution/helper.h"

#include <cstring>
#include "src/common/utils.h"
#include "src/naive/handle.h"

using namespace megdnn;
using namespace naive;
using namespace convolution;

namespace {
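// Filter visitor for batched convolution: unlike plain convolution, every
// sample in the batch owns its own filter, so the filter pointer is advanced
// by `batch * filter_sizes` elements.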
struct BatchConvFilterVisitor {
    template <typename ftype>
    static ftype* get_current_ptr(ftype* fptr, size_t batch, size_t /* oc */,
                                  size_t /* oh */, size_t /* ow */,
                                  size_t filter_sizes) {
        return fptr + batch * filter_sizes;
    }
};
}  // namespace

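// Workspace layout: slot 0 holds the conv + bias result in the bias dtype when
// it differs from the dst dtype; slot 1 holds two float32 buffers used when an
// optional z tensor has to be added before the activation.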
WorkspaceBundle BatchConvBiasForwardImpl::get_workspace_bundle(
        dt_byte* raw_ptr, const TensorLayout& /* src */,
        const TensorLayout& /* flt */, const TensorLayout& bias,
        const TensorLayout& z, const TensorLayout& dst) {
    size_t ws_bias_size = 0, ws_z_size = 0;
    if (bias.dtype.enumv() != dst.dtype.enumv()) {
        ws_bias_size = TensorLayout{dst, bias.dtype}.span().dist_byte();
    }
    if (z.ndim > 0) {
        megdnn_assert(z.dtype.enumv() == DTypeEnum::QuantizedS8);
        megdnn_assert(z.eq_shape(dst));
        // (w * f + b).astype(float) + (z).astype(float)
        size_t f32_z_size =
                TensorLayout{z, dtype::Float32()}.span().dist_byte();
        ws_z_size = 2 * f32_z_size;
    }
    return WorkspaceBundle{raw_ptr, {ws_bias_size, ws_z_size}};
}

size_t BatchConvBiasForwardImpl::get_workspace_in_bytes(
        const TensorLayout& src, const TensorLayout& flt,
        const TensorLayout& bias, const TensorLayout& z,
        const TensorLayout& dst) {
    return get_workspace_bundle(nullptr, src, flt, bias, z, dst)
            .total_size_in_bytes();
}

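// Forward entry point: compute conv + bias into `sfb` (either dst itself, or a
// workspace tensor in the bias dtype when the dtypes differ), then add the
// optional z input and apply the activation in a second pass.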
void BatchConvBiasForwardImpl::exec(_megdnn_tensor_in src,
                                    _megdnn_tensor_in filter,
                                    _megdnn_tensor_in bias, _megdnn_tensor_in z,
                                    _megdnn_tensor_out dst,
                                    _megdnn_workspace workspace) {
    auto filter_meta = check_exec(src.layout, filter.layout, bias.layout,
                                  z.layout, dst.layout, workspace.size);
    WorkspaceBundle ws =
            get_workspace_bundle(workspace.raw_ptr, src.layout, filter.layout,
                                 bias.layout, z.layout, dst.layout);
    auto sfb = dst;
    if (bias.layout.dtype.enumv() != dst.layout.dtype.enumv()) {
        sfb = TensorND{ws.get(0), TensorLayout{dst.layout, bias.layout.dtype}};
    }

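// DISPATCH_RAW emits an `else if` branch that matches the src/filter/bias/output
// dtypes and the compute mode, then dispatches the given kernel on the CPU
// handle; DISPATCH instantiates forward_bias with the corresponding ctypes.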
#define DISPATCH_RAW(in_dt, bias_dt, out_dt, cmode, func)                      \
    else if (src.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv &&    \
             filter.layout.dtype.enumv() == DTypeTrait<dtype::in_dt>::enumv && \
             bias.layout.dtype.enumv() == DTypeTrait<dtype::bias_dt>::enumv && \
             sfb.layout.dtype.enumv() == DTypeTrait<dtype::out_dt>::enumv &&   \
             param().compute_mode == Param::ComputeMode::cmode) {              \
        MEGDNN_DISPATCH_CPU_KERN_OPR(                                          \
                func(src, filter, bias, sfb, nullptr, filter_meta));           \
    }
#define DISPATCH(in_dt, out_dt)                                           \
    DISPATCH_RAW(in_dt, out_dt, out_dt, DEFAULT,                          \
                 (forward_bias<DTypeTrait<dtype::in_dt>::ctype,           \
                               DTypeTrait<dtype::in_dt>::ctype,           \
                               DTypeTrait<dtype::out_dt>::ctype,          \
                               DTypeTrait<dtype::out_dt>::ctype,          \
                               BatchConvBiasForward::CanonizedFilterMeta, \
                               BatchConvFilterVisitor>))
    if (0) {
    }
    DISPATCH(QuantizedS8, QuantizedS32)
    else {
        megdnn_throw(ssprintf(
                "unsupported naive BatchConvBias(%s, %s, %s, %s) -> %s",
                src.layout.dtype.name(), filter.layout.dtype.name(),
                bias.layout.dtype.name(), z.layout.dtype.name(),
                dst.layout.dtype.name()));
    }
#undef DISPATCH
#undef DISPATCH_RAW
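    // Second pass: add the optional z input and apply the activation
    // (param().nonlineMode), writing the final result into dst; ws.get(1)
    // provides scratch space for the float32 intermediates.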
    MEGDNN_DISPATCH_CPU_KERN_OPR(handle_z_inp_and_activation_naive(
            param().nonlineMode, sfb, z, dst,
            reinterpret_cast<dt_byte*>(ws.get(1))));
}

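// The naive backend exposes exactly one algorithm: the handle's default
// batch-conv-bias forward implementation.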
std::vector<BatchConvBiasForward::Algorithm*>
BatchConvBiasForwardImpl::get_all_algorithms(const TensorLayout&,
                                             const TensorLayout&,
                                             const TensorLayout&,
                                             const TensorLayout&,
                                             const TensorLayout&) {
    return {static_cast<HandleImpl*>(handle())
                    ->default_batch_conv_bias_fwd_algo()};
}

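// The heuristic trivially returns the default algorithm and asserts that it
// carries the attributes requested by the caller.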
BatchConvBiasForward::Algorithm*
BatchConvBiasForwardImpl::get_algorithm_heuristic(
        const TensorLayout& /* src */, const TensorLayout& /* filter */,
        const TensorLayout& /* bias */, const TensorLayout& /* z */,
        const TensorLayout& /* dst */, size_t /* workspace_limit_in_bytes */,
        const AlgoAttribute& attr) {
    auto algo = static_cast<HandleImpl*>(handle())
                        ->default_batch_conv_bias_fwd_algo();
    megdnn_assert(algo->contain_attribute(attr),
                  "require algorithm with attribute %s, but heuristic "
                  "algorithm (%s) has attribute %s",
                  Algorithm::attribute_str(attr).c_str(), algo->name(),
                  Algorithm::attribute_str(algo->attribute()).c_str());
    return algo;
}

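// Map an algorithm descriptor back to the (only) default algorithm, asserting
// that the descriptor actually refers to it.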
BatchConvBiasForward::Algorithm*
BatchConvBiasForwardImpl::get_algorithm_from_desc(const AlgorithmDesc& desc) {
    Algorithm* ret = static_cast<HandleImpl*>(handle())
                             ->default_batch_conv_bias_fwd_algo();
    megdnn_assert(desc == ret->info().desc);
    return ret;
}

// vim: syntax=cpp.doxygen