/**
 * \file dnn/src/cuda/conv_bias/batched_matmul.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "src/common/algo_chooser.h"
#include "src/common/algo_base.h"
#include "src/common/conv_bias.h"
#include "src/cuda/batched_matrix_mul/algo.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.cuh"

using namespace megdnn;
using namespace cuda;
using namespace conv_bias;

namespace {
std::pair<TensorLayoutArray, MatrixMulForward::Param> sub_opr_config(
        const ConvBiasForwardImpl::CanonizedFilterMeta& fm,
        const TensorLayout& src_layout, const TensorLayout&,
        const TensorLayout& dst_layout, const ConvBiasForwardImpl* opr) {
    // A {N, OC, IC}
    // B {N, IC, H * W}
    // C {N, OC, H * W}
    size_t batched = src_layout.shape[0];
    TensorLayout A, B, C;
    A = {{batched, fm.ocpg, fm.icpg}, fm.dtype};
    A.stride[0] = 0;
    B.ndim = 3;
    B.shape[1] = src_layout.shape[1];
    B.shape[2] = src_layout.shape[2] * src_layout.shape[3];
    B.shape[0] = batched;
    B.stride[2] = 1;
    B.stride[1] = src_layout.stride[1];
    B.stride[0] = src_layout.stride[0];
    B.dtype = src_layout.dtype;
    C = {{dst_layout.shape[0], dst_layout.shape[1], B.shape[2]},
         dst_layout.dtype};

    MatrixMulForward::Param param;
    if (opr->param().compute_mode == param::Convolution::ComputeMode::FLOAT32) {
        param.compute_mode = param::MatrixMul::ComputeMode::FLOAT32;
    }

    return {{A, B, C}, param};
}
55 56 57 58 59 60 61 62 63 64 65 66 67

std::pair<TensorLayoutArray, std::unique_ptr<BatchedMatrixMulForward>>
prepare_sub_opr(const ConvBiasForwardImpl::AlgoBase::SizeArgs& args) {
    auto bmatmul_opr = args.handle->create_operator<BatchedMatrixMulForward>();
    set_execution_policy<ConvBiasForward, BatchedMatrixMulForward*>(
            args.opr, bmatmul_opr.get());
    auto&& config =
            sub_opr_config(args.filter_meta, *args.src_layout,
                           *args.filter_layout, *args.dst_layout, args.opr);
    bmatmul_opr->param() = config.second;

    return {config.first, std::move(bmatmul_opr)};
}
68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85
}  // namespace

std::vector<Algorithm::SearchItem>
ConvBiasForwardImpl::AlgoBatchedMatmul::get_subopr_list(
        const TensorLayoutArray& layouts, const OperatorBase* opr) const {
    const ConvBiasForwardImpl* conv_bias_opr =
            static_cast<const ConvBiasForwardImpl*>(opr);
    CanonizedFilterMeta fm =
            conv_bias_opr->check_layout_fwd(layouts[0], layouts[1], layouts[4]);
    auto&& config = sub_opr_config(fm, layouts[0], layouts[1], layouts[4],
                                   conv_bias_opr);

    std::string param_str;
    Algorithm::serialize_write_pod(config.second, param_str);
    return {{Algorithm::OprType::BATCHED_MATRIX_MUL_FORWARD, param_str,
             config.first}};
}

bool ConvBiasForwardImpl::AlgoBatchedMatmul::is_available(
        const SizeArgs& args) const {
    if (args.z_layout->ndim > 0)
        return false;

91
    auto config = prepare_sub_opr(args);
92

93 94 95 96 97 98 99
    auto&& fm = args.filter_meta;
    return fm.format == Param::Format::NCHW &&
           (fm.dtype.enumv() == DTypeEnum::Float32 ||
            fm.dtype.enumv() == DTypeEnum::Float16) &&
           fm.spatial_ndim == 2 && fm.group == 1 && fm.dilation[0] == 1 &&
           fm.dilation[1] == 1 && fm.spatial[0] == 1 && fm.spatial[1] == 1 &&
           fm.padding[0] == 0 && fm.padding[1] == 0 && fm.stride[0] == 1 &&
100
           fm.stride[1] == 1 &&
101 102 103
           get_algorithm(static_cast<BatchedMatrixMulForwardImpl*>(
                                 config.second.get()),
                         config.first[0], config.first[1], config.first[2]);
104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
}

WorkspaceBundle ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_bundle(
        void* ptr, const SizeArgs& args) const {
    auto dst_layout = *args.dst_layout;
    SmallVector<size_t> sizes;
    if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
        dst_layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype,
                                            args.filter_layout->dtype,
                                            dst_layout.dtype);
        sizes.push_back(dst_layout.span().dist_byte());
    }

    SizeArgs conv_args = args;
    conv_args.dst_layout = &dst_layout;
120

121
    auto config = prepare_sub_opr(args);
122 123

    sizes.insert(sizes.begin(),
124
                 config.second->get_workspace_in_bytes(
125
                         config.first[0], config.first[1], config.first[2]));
126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148
    return {ptr, std::move(sizes)};
}

size_t ConvBiasForwardImpl::AlgoBatchedMatmul::get_workspace_in_bytes(
        const SizeArgs& args) const {
    return get_workspace_bundle(nullptr, args).total_size_in_bytes();
}

void ConvBiasForwardImpl::AlgoBatchedMatmul::exec(const ExecArgs& args) const {
    auto bundle = get_workspace_bundle(args.workspace.raw_ptr, args);
    auto conv_dst_tensor = *args.dst_tensor;
    if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
        conv_dst_tensor.raw_ptr = bundle.get(1);
        conv_dst_tensor.layout.dtype = DType();
        args.opr->check_or_deduce_dtype_fwd(args.src_layout->dtype,
                                            args.filter_layout->dtype,
                                            conv_dst_tensor.layout.dtype);
    }

    ExecArgs conv_args = args;
    conv_args.dst_tensor = &conv_dst_tensor;
    conv_args.dst_layout = &conv_dst_tensor.layout;
    {
149
        auto config = prepare_sub_opr(args);
150 151 152 153

        TensorND A{args.filter_tensor->raw_ptr, config.first[0]},
                B{args.src_tensor->raw_ptr, config.first[1]},
                C{args.dst_tensor->raw_ptr, config.first[2]};
154
        config.second->exec(A, B, C, bundle.get_workspace(0));
155 156 157 158 159 160 161
    }
    handle_bias_and_nonlinear(args.handle, args.nonlinear_mode,
                              &conv_dst_tensor, args.dst_tensor,
                              args.bias_tensor);
}

// vim: syntax=cpp.doxygen