winograd_algo.cpp 5.5 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43
/**
 * \file dnn/src/x86/conv_bias/f32/winograd_algo.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2020 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include "src/x86/conv_bias/f32/algos.h"
#include "src/common/utils.h"
#include "src/x86/conv_bias/opr_impl.h"
#include "src/x86/conv_bias/postprocess_helper.h"
#include "src/x86/handle.h"
#include "src/x86/profile.h"
#include "src/x86/conv_bias/f32/strategy.h"

#include "midout.h"

MIDOUT_DECL(megdnn_x86_winograd_fp32)

using namespace megdnn;
using namespace x86;

/* ======================= AlgoFP32WinogradF63_8*8 ======================== */

/*!
 * \brief Report whether this winograd algorithm (output block size 6,
 *      3x3 filter, MK8 matmul) can handle the given convolution.
 *
 * The answer is true only when all of the following hold:
 *  - input and output channels per group are multiples of 8 and there is
 *    exactly one group;
 *  - the underlying matmul algorithm accepts the matmul kernel param derived
 *    from the winograd strategy;
 *  - the operator format is NCHW88, or NCHW88_WINOGRAD with
 *    output_block_size == 6 and an MK8 winograd matmul format;
 *  - the mode is CROSS_CORRELATION, the filter is 3x3, stride and dilation
 *    are both 1x1, the compute mode is DEFAULT and the source dtype is
 *    Float32;
 *  - AVX2 is reported as supported.
 *
 * \param opr operator whose param() supplies format, mode and
 *      output_block_size
 * \param param kernel size param describing filter meta and dtypes
 * \return true if every condition above is met, false otherwise
 */
bool ConvBiasImpl::AlgoFP32WinogradF63_8x8::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    //! NOTE(review): presumably these keep unused-variable warnings away when
    //! the MIDOUT region below is compiled out -- confirm against midout.h
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_x86_winograd_fp32, 1, 0) {
        //! TODO: now nchw88 winograd only support Dense mode
        if (param.filter_meta.icpg % 8 != 0 ||
            param.filter_meta.ocpg % 8 != 0 || param.filter_meta.group != 1)
            return false;
        using Strategy = winograd::winograd_nchw88_6x3_8x8_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        //! matmul kernel param that the wrapped matmul algo is asked about
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK8>(
                        strategy, m_tile_size, param)
                        .get_matmul_kern_param(param);
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW88 ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW88_WINOGRAD &&
                 opr->param().output_block_size == 6 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::MK8)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32 &&
               is_supported(SIMDType::AVX2);
    }
    MIDOUT_END();
    //! fallback result when control leaves the MIDOUT region without returning
    return false;
}

69 70 71 72
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF63_8x8,
                                    winograd::winograd_nchw88_6x3_8x8_f,
                                    megdnn_x86_winograd_fp32,
                                    param::MatrixMul::Format::MK8);
73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90

/* ======================= AlgoFP32WinogradF23_8*8 ======================== */

/*!
 * \brief Report whether this winograd algorithm (output block size 2,
 *      3x3 filter, MK8 matmul) can handle the given convolution.
 *
 * The answer is true only when all of the following hold:
 *  - input and output channels per group are multiples of 8 and there is
 *    exactly one group;
 *  - the underlying matmul algorithm accepts the matmul kernel param derived
 *    from the winograd strategy;
 *  - the operator format is NCHW88, or NCHW88_WINOGRAD with
 *    output_block_size == 2 and an MK8 winograd matmul format;
 *  - the mode is CROSS_CORRELATION, the filter is 3x3, stride and dilation
 *    are both 1x1, the compute mode is DEFAULT and the source dtype is
 *    Float32;
 *  - AVX2 is reported as supported.
 *
 * \param opr operator whose param() supplies format, mode and
 *      output_block_size
 * \param param kernel size param describing filter meta and dtypes
 * \return true if every condition above is met, false otherwise
 */
bool ConvBiasImpl::AlgoFP32WinogradF23_8x8::usable(
        fallback::ConvBiasImpl* opr, const NCBKernSizeParam& param,
        AlgoSelectionStrategy /*algo_selection_strategy*/) const {
    //! NOTE(review): presumably these keep unused-variable warnings away when
    //! the MIDOUT region below is compiled out -- confirm against midout.h
    MEGDNN_MARK_USED_VAR(param);
    MEGDNN_MARK_USED_VAR(opr);
    MIDOUT_BEGIN(megdnn_x86_winograd_fp32, 2, 0) {
        //! TODO: now nchw88 winograd only support Dense mode
        if (param.filter_meta.icpg % 8 != 0 ||
            param.filter_meta.ocpg % 8 != 0 || param.filter_meta.group != 1)
            return false;
        using Strategy = winograd::winograd_nchw88_2x3_8x8_f;
        Strategy strategy(param.src_type, param.filter_type, param.dst_type);
        //! matmul kernel param that the wrapped matmul algo is asked about
        auto&& matmul_param =
                megdnn::winograd::ConvBias<Strategy,
                                           param::MatrixMul::Format::MK8>(
                        strategy, m_tile_size, param)
                        .get_matmul_kern_param(param);
        return m_matmul_algo->usable(matmul_param) &&
               (opr->param().format == param::ConvBias::Format::NCHW88 ||
                (opr->param().format ==
                         param::ConvBias::Format::NCHW88_WINOGRAD &&
                 opr->param().output_block_size == 2 &&
                 param.winograd_matmul_format ==
                         param::MatrixMul::Format::MK8)) &&
               opr->param().mode == param::ConvBias::Mode::CROSS_CORRELATION &&
               (param.filter_meta.spatial[0] == param.filter_meta.spatial[1] &&
                param.filter_meta.spatial[0] == 3) &&
               (param.filter_meta.stride[0] == param.filter_meta.stride[1] &&
                param.filter_meta.stride[0] == 1) &&
               (param.filter_meta.dilation[0] ==
                        param.filter_meta.dilation[1] &&
                param.filter_meta.dilation[0] == 1) &&
               param.compute_mode == param::ConvBias::ComputeMode::DEFAULT &&
               param.src_type.enumv() == DTypeEnum::Float32 &&
               is_supported(SIMDType::AVX2);
    }
    MIDOUT_END();
    //! fallback result when control leaves the MIDOUT region without returning
    return false;
}

116 117 118 119
MEGDNN_WINOGRAD_ALGO_FUN_DEFINE_ALL(AlgoFP32WinogradF23_8x8,
                                    winograd::winograd_nchw88_2x3_8x8_f,
                                    megdnn_x86_winograd_fp32,
                                    param::MatrixMul::Format::MK8);
120 121

// vim: syntax=cpp.doxygen