// device_options.cpp
#include <iostream>
#include <sstream>
#include "lite/global.h"
#include "megbrain/comp_node_env.h"
#include "misc.h"
#include "device_options.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"

DECLARE_bool(weight_preprocess);

using namespace lar;

/////////////////// XPUDeviceOption //////////////////////
namespace lar {
// Apply the device options to a Lite model.
//
// BEFORE_MODEL_LOAD: writes the device type into the model config. Any of the
// cpu / cpu_default / multithread / multithread_default switches selects
// LITE_CPU; when built with LITE_WITH_CUDA, enable_cuda is applied afterwards
// and therefore overrides it.
// AFTER_NETWORK_CREATED: configures inplace mode, thread number and per-thread
// core affinity on the created network.
//
// runtime_param: supplies the current run stage.
// model:         Lite model whose config / network is modified.
template <>
void XPUDeviceOption::config_model_internel<ModelLite>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) ||
            (enable_multithread_default)) {
            LITE_LOG("using cpu device\n");
            model->get_config().device_type = LiteDeviceType::LITE_CPU;
        }
#if LITE_WITH_CUDA
        if (enable_cuda) {
            LITE_LOG("using cuda device\n");
            model->get_config().device_type = LiteDeviceType::LITE_CUDA;
        }
#endif
    } else if (runtime_param.stage == RunStage::AFTER_NETWORK_CREATED) {
        auto&& network = model->get_lite_network();
        if (enable_cpu_default) {
            LITE_LOG("using cpu default device\n");
            lite::Runtime::set_cpu_inplace_mode(network);
        }
        if (enable_multithread) {
            LITE_LOG("using multithread device\n");
            lite::Runtime::set_cpu_threads_number(network, thread_num);
        }
        if (enable_multithread_default) {
            LITE_LOG("using multithread  default device\n");
            lite::Runtime::set_cpu_inplace_mode(network);
            lite::Runtime::set_cpu_threads_number(network, thread_num);
        }
        if (enable_set_core_ids) {
            std::string core_str;
            for (auto id : core_ids) {
                core_str += std::to_string(id) + ",";
            }
            LITE_LOG("multi thread core ids: %s\n", core_str.c_str());
            // The callback is handed over to the network runtime, so it must
            // not read core_ids through `this`: the member vector is modified
            // again by later config_model() calls. Keep a private copy instead.
            lite::ThreadAffinityCallback affinity_callback =
                    [ids = core_ids](size_t thread_id) {
                        mgb::sys::set_cpu_affinity({ids[thread_id]});
                    };
            lite::Runtime::set_runtime_thread_affinity(network, affinity_callback);
        }
    }
}

// Apply the device options to an MDL (megbrain) model.
//
// Only the BEFORE_MODEL_LOAD stage is handled: each enabled switch installs a
// comp_node_mapper into the model's mdl config. The switches are checked in a
// fixed order (cpu, cuda, cpu_default, multithread, multithread_default) and a
// later one replaces the mapper installed by an earlier one. When core ids are
// requested, the currently installed mapper is used to resolve a comp node and
// the affinity callback is registered on its cpu env.
//
// runtime_param: supplies the current run stage.
// model:         MDL model whose config is modified.
template <>
void XPUDeviceOption::config_model_internel<ModelMdl>(
        RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        if (enable_cpu) {
            mgb_log("using cpu device\n");
            model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
                loc.type = mgb::CompNode::DeviceType::CPU;
            };
        }
#if LITE_WITH_CUDA
        if (enable_cuda) {
            mgb_log("using cuda device\n");
            model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
                // only locators without an explicit type are redirected to CUDA
                if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
                    loc.type = mgb::CompNode::DeviceType::CUDA;
                }
                loc.device = 0;
            };
        }
#endif
        if (enable_cpu_default) {
            mgb_log("using cpu default device\n");
            model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
                loc.type = mgb::CompNode::DeviceType::CPU;
                loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
            };
        }
        if (enable_multithread) {
            // cast so that the argument matches %ld whatever integral type
            // thread_num is declared with
            mgb_log("using multithread(threads number:%ld) device\n",
                    static_cast<long>(thread_num));
            // The mapper is stored in the model config and called later, so
            // the thread number is captured by value rather than via `this`.
            model->get_mdl_config().comp_node_mapper =
                    [nr_threads = thread_num](mgb::CompNode::Locator& loc) {
                        loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
                        loc.device = 0;
                        loc.stream = nr_threads;
                    };
        }
        if (enable_multithread_default) {
            mgb_log("using multithread default device\n");
            model->get_mdl_config().comp_node_mapper =
                    [nr_threads = thread_num](mgb::CompNode::Locator& loc) {
                        loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
                        loc.device = mgb::CompNode::Locator::DEVICE_MULTITHREAD_DEFAULT;
                        loc.stream = nr_threads;
                    };
        }
        if (enable_set_core_ids) {
            std::string core_str;
            for (auto id : core_ids) {
                core_str += std::to_string(id) + ",";
            }
            mgb_log("set multi thread core ids:%s\n", core_str.c_str());
            // Own a copy of the ids: core_ids is modified again by later
            // config_model() calls while the callback stays registered.
            auto affinity_callback = [ids = core_ids](size_t thread_id) {
                mgb::sys::set_cpu_affinity({ids[thread_id]});
            };
            // Resolve the comp node through the mapper installed above.
            // NOTE(review): relies on a multithread mapper being installed in
            // this same call; config_model() asserts that precondition.
            mgb::CompNode::Locator loc;
            model->get_mdl_config().comp_node_mapper(loc);
            auto comp_node = mgb::CompNode::load(loc);
            mgb::CompNodeEnv::from_comp_node(comp_node).cpu_env().set_affinity(
                    affinity_callback);
        }
    }
}
}  // namespace lar

125
void XPUDeviceOption::update() {
126 127
    m_option_name = "xpu_device";
    enable_cpu = FLAGS_cpu;
128
#if LITE_WITH_CUDA
129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
    enable_cuda = FLAGS_cuda;
#endif
    enable_cpu_default = FLAGS_cpu_default;

    if (FLAGS_multithread >= 0) {
        thread_num = FLAGS_multithread;
        enable_multithread = true;
    }

    if (FLAGS_multithread_default >= 0) {
        thread_num = FLAGS_multithread_default;
        enable_multithread_default = true;
    }

    if (!FLAGS_multi_thread_core_ids.empty()) {
144 145 146
        mgb_assert(
                enable_multithread || enable_multithread_default,
                "core ids should be set after --multithread or --multithread-default");
147 148 149 150 151 152 153 154 155 156 157
        std::stringstream id_stream(FLAGS_multi_thread_core_ids);
        std::string id;
        size_t thread_cnt = 0;
        while (getline(id_stream, id, ',')) {
            thread_cnt++;
            core_ids.push_back(atoi(id.c_str()));
        }
        mgb_assert(
                thread_cnt == thread_num,
                "core ids number should be same with thread number set before");
        enable_set_core_ids = true;
158 159
    } else {
        enable_set_core_ids = false;
160 161
    }

162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185
    m_option = {
        {"cpu", lar::Bool::make(false)},
#if LITE_WITH_CUDA
        {"cuda", lar::Bool::make(false)},
#endif
        {"cpu_default", lar::Bool::make(false)},
        {"multithread", lar::NumberInt32::make(-1)},
        {"multithread_default", lar::NumberInt32::make(-1)},
        {"multi_thread_core_ids", lar::String::make("")},
    };
    std::static_pointer_cast<lar::Bool>(m_option["cpu"])->set_value(FLAGS_cpu);
#if LITE_WITH_CUDA
    std::static_pointer_cast<lar::Bool>(m_option["cuda"])->set_value(FLAGS_cuda);
#endif
    std::static_pointer_cast<lar::Bool>(m_option["cpu_default"])
            ->set_value(FLAGS_cpu_default);
    std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"])
            ->set_value(FLAGS_multithread);
    std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"])
            ->set_value(FLAGS_multithread_default);
    std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"])
            ->set_value(FLAGS_multi_thread_core_ids);
}
// Out-of-class definition of the static validity flag. It has no explicit
// initializer, so as an object with static storage it starts out false.
// is_valid() returns true when this flag is set even if no device flag is given.
bool XPUDeviceOption::m_valid;
186 187
bool XPUDeviceOption::is_valid() {
    bool ret = FLAGS_cpu || FLAGS_cpu_default;
188
#if LITE_WITH_CUDA
189 190 191 192 193 194
    ret = ret || FLAGS_cuda;
#endif
    ret = ret || FLAGS_multithread >= 0;
    ret = ret || FLAGS_multithread_default >= 0;
    ret = ret || !FLAGS_multi_thread_core_ids.empty();

195
    return ret || m_valid;
196 197 198 199 200
}

std::shared_ptr<OptionBase> XPUDeviceOption::create_option() {
    static std::shared_ptr<lar::XPUDeviceOption> option(new XPUDeviceOption);
    if (XPUDeviceOption::is_valid()) {
201
        option->update();
202 203 204 205 206 207 208 209
        return std::static_pointer_cast<lar::OptionBase>(option);
    } else {
        return nullptr;
    }
}

// Re-read the option values stored in m_option into the enable_* / thread_num /
// core_ids members and then hand over to CONFIG_MODEL_FUN.
//
// Asserts (mgb_assert) when multithread and multithread_default are both set,
// when core ids are given without either of them, or when the number of core
// ids differs from thread_num.
//
// runtime_param: run-stage information used by CONFIG_MODEL_FUN.
// model:         model to configure.
void XPUDeviceOption::config_model(
        RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
    enable_cpu = std::static_pointer_cast<lar::Bool>(m_option["cpu"])->get_value();
#if LITE_WITH_CUDA
    enable_cuda = std::static_pointer_cast<lar::Bool>(m_option["cuda"])->get_value();
#endif
    enable_cpu_default =
            std::static_pointer_cast<lar::Bool>(m_option["cpu_default"])->get_value();

    // a non-negative value both enables the mode and is the thread number
    int32_t num_of_thread =
            std::static_pointer_cast<lar::NumberInt32>(m_option["multithread"])
                    ->get_value();
    enable_multithread = num_of_thread >= 0;
    int32_t num_of_thread_dft =
            std::static_pointer_cast<lar::NumberInt32>(m_option["multithread_default"])
                    ->get_value();
    enable_multithread_default = num_of_thread_dft >= 0;
    mgb_assert(
            num_of_thread < 0 || num_of_thread_dft < 0,
            "multithread and multithread_default should not bet set at the same time");
    thread_num = num_of_thread >= 0 ? num_of_thread
                                    : (num_of_thread_dft >= 0 ? num_of_thread_dft : -1);

    std::string core_id_str =
            std::static_pointer_cast<lar::String>(m_option["multi_thread_core_ids"])
                    ->get_value();
    // This function is expected to run more than once on the same instance
    // (config_model_internel handles more than one stage), and update() fills
    // core_ids as well, so start from an empty list instead of appending.
    core_ids.clear();
    if (!core_id_str.empty()) {
        mgb_assert(
                enable_multithread || enable_multithread_default,
                "core ids should be set after --multithread or --multithread-default");
        std::stringstream id_stream(core_id_str);
        std::string id;
        while (getline(id_stream, id, ',')) {
            core_ids.push_back(atoi(id.c_str()));
        }
        const size_t thread_cnt = core_ids.size();
        mgb_assert(
                thread_cnt == thread_num,
                "core ids number should be same with thread number set before");
        enable_set_core_ids = true;
    } else {
        enable_set_core_ids = false;
    }

    CONFIG_MODEL_FUN;
}
///////////////////////// xpu gflags ////////////////////////////
// device selection switches
DEFINE_bool(cpu, false, "set CPU device as running device");
#if LITE_WITH_CUDA
DEFINE_bool(cuda, false, "set CUDA device as running device ");
#endif
DEFINE_bool(cpu_default, false, "set running device as CPU device with inplace mode");

// multithread switches: the value is the thread number, the default -1 means
// the flag is not given (the code above tests for >= 0)
DEFINE_int32(multithread, -1, "set multithread device as running device");
DEFINE_int32(
        multithread_default, -1,
        "set multithread device as running device with inplace mode");

// comma separated core ids, parsed in update() / config_model()
DEFINE_string(multi_thread_core_ids, "", "set multithread core id");

// register the option factory and its validity setter under the name xpu_device
REGIST_OPTION_CREATOR(xpu_device, lar::XPUDeviceOption::create_option);
REGIST_OPTION_VALIDATER(xpu_device, lar::XPUDeviceOption::set_valid);