layout_options.cpp 4.9 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171
/**
 * \file lite/load_and_run/src/options/layout_options.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */

#include <gflags/gflags.h>

#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"

#include "layout_options.h"
namespace lar {
/**
 * Apply the layout optimization selected by is_valid() to a Lite model.
 *
 * Only acts at RunStage::BEFORE_MODEL_LOAD. For a supported layout it logs a
 * warning naming the layout and sets the matching
 * `model->get_config().options.enable_<layout>` field to true.
 * CHWN4 is rejected via LITE_THROW, since the Lite path does not support it.
 * Any other option_flag value (e.g. 0) is ignored.
 */
template <>
void LayoutOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage != RunStage::BEFORE_MODEL_LOAD) {
        return;
    }

    // `layout` is stringized for the log message and token-pasted into the
    // option field name. The `break` is left to each case label so the
    // switch's control flow is visible where it happens.
#define LAR_LITE_ENABLE_LAYOUT(layout)                       \
    do {                                                     \
        LITE_WARN("enable " #layout " optimization");        \
        model->get_config().options.enable_##layout = true;  \
    } while (0)

    switch (option_flag) {
        case OptLayoutType::NCHW4:
            LAR_LITE_ENABLE_LAYOUT(nchw4);
            break;
        case OptLayoutType::CHWN4:
            LITE_THROW("lite model unsupport chwn4 layout");
            break;
        case OptLayoutType::NCHW44:
            LAR_LITE_ENABLE_LAYOUT(nchw44);
            break;
        case OptLayoutType::NCHW88:
            LAR_LITE_ENABLE_LAYOUT(nchw88);
            break;
        case OptLayoutType::NCHW32:
            LAR_LITE_ENABLE_LAYOUT(nchw32);
            break;
        case OptLayoutType::NCHW64:
            LAR_LITE_ENABLE_LAYOUT(nchw64);
            break;
        case OptLayoutType::NHWCD4:
            LAR_LITE_ENABLE_LAYOUT(nhwcd4);
            break;
        case OptLayoutType::NCHW44_DOT:
            LAR_LITE_ENABLE_LAYOUT(nchw44_dot);
            break;
        default:
            break;
    }
#undef LAR_LITE_ENABLE_LAYOUT
}

/**
 * Apply the layout optimization selected by is_valid() to an mdl model.
 *
 * Only acts at RunStage::BEFORE_MODEL_LOAD. It logs a warning naming the
 * layout and calls `graph_opt.enable_<layout>()` on the model's computing
 * graph options, with debug log lines before and after the dispatch.
 * Unlike the Lite path, CHWN4 is supported here. Any other option_flag value
 * (e.g. 0) is ignored.
 */
template <>
void LayoutOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage != RunStage::BEFORE_MODEL_LOAD) {
        return;
    }

    mgb_log_debug("mdl  layout config start");

    // Same shape as the Lite helper: stringize for the log, token-paste for
    // the graph-option member function; callers supply the `break`.
#define LAR_MDL_ENABLE_LAYOUT(layout)                                            \
    do {                                                                         \
        mgb_log_warn("enable " #layout " optimization");                         \
        model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \
    } while (0)

    switch (option_flag) {
        case OptLayoutType::NCHW4:
            LAR_MDL_ENABLE_LAYOUT(nchw4);
            break;
        case OptLayoutType::CHWN4:
            LAR_MDL_ENABLE_LAYOUT(chwn4);
            break;
        case OptLayoutType::NCHW44:
            LAR_MDL_ENABLE_LAYOUT(nchw44);
            break;
        case OptLayoutType::NCHW88:
            LAR_MDL_ENABLE_LAYOUT(nchw88);
            break;
        case OptLayoutType::NCHW32:
            LAR_MDL_ENABLE_LAYOUT(nchw32);
            break;
        case OptLayoutType::NCHW64:
            LAR_MDL_ENABLE_LAYOUT(nchw64);
            break;
        case OptLayoutType::NHWCD4:
            LAR_MDL_ENABLE_LAYOUT(nhwcd4);
            break;
        case OptLayoutType::NCHW44_DOT:
            LAR_MDL_ENABLE_LAYOUT(nchw44_dot);
            break;
        default:
            break;
    }
#undef LAR_MDL_ENABLE_LAYOUT

    mgb_log_debug("mdl layout config end");
}
}  // namespace lar

using namespace lar;

// Definition of the static layout selection written by LayoutOption::is_valid()
// and read by the config_model_internel specializations. With no initializer,
// this static-storage object starts zero-initialized (value 0), which is also
// the value is_valid() stores when the flags are invalid.
OptLayoutType LayoutOption::option_flag;

/// Construct the layout option and set its option name to "layout".
LayoutOption::LayoutOption() {
    this->m_option_name = "layout";
}

/**
 * Inspect the layout command-line flags and record the requested layout.
 *
 * Each --enable_<layout> flag contributes one bit to a mask, in this order:
 * nchw4, chwn4, nchw44, nchw88, nchw32, nchw64, nhwcd4, nchw44_dot
 * (bit 0 .. bit 7).
 *
 * The request is valid only when exactly one bit is set. The mask is then
 * cast straight to OptLayoutType. This assumes the enum's enumerators carry
 * these same single-bit values — TODO confirm against layout_options.h.
 *
 * NOTE(review): when two or more flags are set, this returns false without
 * any diagnostic, so the layout option is silently dropped.
 *
 * @return true iff exactly one layout flag is set.
 *         Side effect: option_flag is set to the selected layout, or to 0
 *         when the result is false.
 */
bool LayoutOption::is_valid() {
    // Position in this table is the bit index; keep in sync with OptLayoutType.
    const bool layout_requested[] = {
            FLAGS_enable_nchw4,  FLAGS_enable_chwn4,  FLAGS_enable_nchw44,
            FLAGS_enable_nchw88, FLAGS_enable_nchw32, FLAGS_enable_nchw64,
            FLAGS_enable_nhwcd4, FLAGS_enable_nchw44_dot,
    };

    size_t valid_flag = 0;
    size_t bit = 0;
    for (const bool requested : layout_requested) {
        if (requested) {
            valid_flag |= size_t{1} << bit;
        }
        ++bit;
    }

    // Non-zero and a power of two <=> exactly one flag was set.
    const bool ret = valid_flag != 0 && (valid_flag & (valid_flag - 1)) == 0;
    option_flag = static_cast<OptLayoutType>(ret ? valid_flag : 0);
    return ret;
}

/**
 * Creator function for the "layout" option (passed to REGIST_OPTION_CREATOR
 * below).
 *
 * @return the single LayoutOption instance shared across calls, upcast to
 *         OptionBase, when is_valid() accepts the flags; nullptr otherwise.
 */
std::shared_ptr<OptionBase> LayoutOption::create_option() {
    // Function-local static: built on the first call, before validation,
    // and reused by every later call.
    static std::shared_ptr<LayoutOption> instance(new LayoutOption);

    if (!LayoutOption::is_valid()) {
        return nullptr;
    }
    return std::static_pointer_cast<OptionBase>(instance);
}

/**
 * Configure `model` for the stage given in `runtime_param`.
 *
 * All work is delegated to CONFIG_MODEL_FUN, which is defined elsewhere.
 * Presumably it dispatches to the config_model_internel<ModelLite> /
 * config_model_internel<ModelMdl> specialization matching the model's
 * concrete type — TODO confirm against the macro definition.
 */
void LayoutOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    CONFIG_MODEL_FUN;
}

// Command-line switches consumed by LayoutOption::is_valid(). At most one of
// them may be set: is_valid() rejects any combination of two or more, and
// the layout option is then not created.
DEFINE_bool(enable_nchw4, false, "enable nchw4 layout optimization!!");
DEFINE_bool(enable_chwn4, false, "enable chwn4 layout optimization!!");
DEFINE_bool(enable_nchw44, false, "enable nchw44 layout optimization!!");
DEFINE_bool(enable_nchw88, false, "enable nchw88 layout optimization!!");
DEFINE_bool(enable_nchw32, false, "enable nchw32 layout optimization!!");
DEFINE_bool(enable_nchw64, false, "enable nchw64 layout optimization!!");
DEFINE_bool(enable_nhwcd4, false, "enable nhwcd4 layout optimization!!");
// Help text previously read "nchw444-dot"; it now matches the flag name.
DEFINE_bool(enable_nchw44_dot, false, "enable nchw44_dot layout optimization!!");

// Register LayoutOption::create_option under the option name "layout"
// (REGIST_OPTION_CREATOR is defined elsewhere in load_and_run).
REGIST_OPTION_CREATOR(layout, lar::LayoutOption::create_option);