Commit ff239c63 authored by Megvii Engine Team

feat(lite): add unit test for lar

GitOrigin-RevId: f3ba1e8a9ae3618f06cea32111a7d40bc0dec353
Parent 7bf1c38c
load("//brain/megbrain/lite:flags.bzl","pthread_select", "lite_opts")
cc_library(
name = "mgblar",
name = "lar_object",
srcs = glob(["src/**/*.cpp"], exclude = ["src/main.cpp"]),
hdrs = glob(["src/**/*.h"]),
includes = ["src"],
......@@ -28,7 +28,8 @@ cc_megvii_binary(
"no_exceptions",
"no_rtti",
]),
internal_deps = [":mgblar"],
internal_deps = [":lar_object"],
visibility = ["//visibility:public"],
)
# Build load_and_run for lite
include_directories(PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>)
file(GLOB_RECURSE SOURCES ./*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp)
file(GLOB_RECURSE SOURCES src/**/*.cpp ${PROJECT_SOURCE_DIR}/lite/src/pack_model/*.cpp)
add_executable(load_and_run ${SOURCES})
target_link_libraries(load_and_run lite_static)
target_link_libraries(load_and_run megbrain)
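# build everything except src/main.cpp into an object library so that load_and_run
# and lite_test can link against the same lar sources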
add_library(lar_object OBJECT ${SOURCES})
target_link_libraries(lar_object lite_static)
target_link_libraries(lar_object megbrain)
if(APPLE)
target_link_libraries(load_and_run gflags)
target_link_libraries(lar_object gflags)
else()
target_link_libraries(load_and_run gflags -Wl,--version-script=${MGE_VERSION_SCRIPT})
target_link_libraries(lar_object gflags -Wl,--version-script=${MGE_VERSION_SCRIPT})
endif()
if(LITE_BUILD_WITH_MGE
AND NOT WIN32
AND NOT APPLE)
# FIXME: third_party cpp_redis does not support building with clang-cl
target_include_directories(lar_object PRIVATE ${CPP_REDIS_INCLUDES})
endif()
add_executable(load_and_run src/main.cpp)
target_link_libraries(load_and_run lar_object)
if(LITE_BUILD_WITH_RKNPU)
# rknn sdk 1.0.0 depends on libc++_shared; use gold to remove the NEEDED shared-object symbol check
target_link_options(load_and_run PRIVATE "-fuse-ld=gold")
endif()
if(MGE_WITH_ROCM)
if(LITE_BUILD_WITH_MGE AND MGE_WITH_ROCM)
message(WARNING "MGE_WITH_ROCM is valid link to megdnn")
# FIXME: hip obj can not find cpp obj only through lite_static
target_link_libraries(load_and_run megdnn)
endif()
......@@ -30,17 +42,11 @@ if(UNIX)
endif()
endif()
if(LITE_BUILD_WITH_MGE
AND NOT WIN32
AND NOT APPLE)
# FIXME: third_party cpp_redis does not support building with clang-cl
target_include_directories(load_and_run PRIVATE ${CPP_REDIS_INCLUDES})
endif()
install(
TARGETS load_and_run
EXPORT ${LITE_EXPORT_TARGETS}
RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR})
if(BUILD_SHARED_LIBS)
if(LITE_BUILD_WITH_MGE
AND NOT WIN32
......@@ -48,7 +54,7 @@ if(BUILD_SHARED_LIBS)
# FIXME: third_party cpp_redis does not support building with clang-cl
list(APPEND SOURCES ${CPP_REDIS_SRCS})
endif()
add_executable(load_and_run_depends_shared ${SOURCES})
add_executable(load_and_run_depends_shared ${SOURCES} src/main.cpp)
target_link_libraries(load_and_run_depends_shared lite_shared)
target_link_libraries(load_and_run_depends_shared gflags)
target_link_libraries(load_and_run_depends_shared megengine)
......@@ -58,7 +64,8 @@ if(BUILD_SHARED_LIBS)
target_link_options(load_and_run_depends_shared PRIVATE "-fuse-ld=gold")
endif()
if(MGE_WITH_ROCM)
if(LITE_BUILD_WITH_MGE AND MGE_WITH_ROCM)
message(WARNING "MGE_WITH_ROCM is valid link to megdnn")
# FIXME: hip obj can not find cpp obj only through lite_static
target_link_libraries(load_and_run_depends_shared megdnn)
endif()
......
......@@ -30,6 +30,8 @@ enum class RunStage {
AFTER_MODEL_RUNNING = 7,
GLOBAL_OPTIMIZATION = 8,
UPDATE_IO = 9,
};
/*!
* \brief: type of different model
......
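The two new stages (GLOBAL_OPTIMIZATION, UPDATE_IO) feed the same per-stage dispatch every option already implements in config_model_internel. A minimal sketch of that dispatch, using a hypothetical DemoOption in place of the real options:
template <typename ModelImpl>
void DemoOption::config_model_internel(
        RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) {
    // each option reacts only to the stages it cares about
    if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
        // adjust load-time configuration on the model
    } else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
        // run global graph optimization, e.g. layout transform
    } else if (runtime_param.stage == RunStage::UPDATE_IO) {
        // refresh the model's inputs/outputs after the graph was rewritten
    }
}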
#include <gflags/gflags.h>
#include <string>
#include "misc.h"
#include "strategys/strategy.h"
std::string simple_usage = R"(
load_and_run: load_and_run <model_path> [options Flags...]
......@@ -29,6 +29,8 @@ More details using "--help" to get!!
)";
int main(int argc, char** argv) {
mgb::set_log_level(mgb::LogLevel::INFO);
lite::set_log_level(LiteLogLevel::INFO);
std::string usage = "load_and_run <model_path> [options Flags...]";
if (argc < 2) {
printf("usage: %s\n", simple_usage.c_str());
......
......@@ -8,17 +8,17 @@ DECLARE_bool(share_param_mem);
using namespace lar;
ModelLite::ModelLite(const std::string& path) : model_path(path) {
LITE_WARN("creat lite model use CPU as default comp node");
LITE_LOG("creat lite model use CPU as default comp node");
};
void ModelLite::load_model() {
m_network = std::make_shared<lite::Network>(config, IO);
if (enable_layout_transform) {
LITE_WARN("enable layout transform while load model for lite");
LITE_LOG("enable layout transform while load model for lite");
lite::Runtime::enable_global_layout_transform(m_network);
}
if (share_model_mem) {
//! WARNING: it may not be correct to share param memory here
LITE_WARN("enable share model memory");
LITE_LOG("enable share model memory");
FILE* fin = fopen(model_path.c_str(), "rb");
LITE_ASSERT(fin, "failed to open %s: %s", model_path.c_str(), strerror(errno));
......
......@@ -19,27 +19,27 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if ((enable_cpu) || (enable_cpu_default) || (enable_multithread) ||
(enable_multithread_default)) {
LITE_WARN("using cpu device\n");
LITE_LOG("using cpu device\n");
model->get_config().device_type = LiteDeviceType::LITE_CPU;
}
#if LITE_WITH_CUDA
if (enable_cuda) {
LITE_WARN("using cuda device\n");
LITE_LOG("using cuda device\n");
model->get_config().device_type = LiteDeviceType::LITE_CUDA;
}
#endif
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
auto&& network = model->get_lite_network();
if (enable_cpu_default) {
LITE_WARN("using cpu default device\n");
LITE_LOG("using cpu default device\n");
lite::Runtime::set_cpu_inplace_mode(network);
}
if (enable_multithread) {
LITE_WARN("using multithread device\n");
LITE_LOG("using multithread device\n");
lite::Runtime::set_cpu_threads_number(network, thread_num);
}
if (enable_multithread_default) {
LITE_WARN("using multithread default device\n");
LITE_LOG("using multithread default device\n");
lite::Runtime::set_cpu_inplace_mode(network);
lite::Runtime::set_cpu_threads_number(network, thread_num);
}
......@@ -48,7 +48,7 @@ void XPUDeviceOption::config_model_internel<ModelLite>(
for (auto id : core_ids) {
core_str += std::to_string(id) + ",";
}
LITE_WARN("multi thread core ids: %s\n", core_str.c_str());
LITE_LOG("multi thread core ids: %s\n", core_str.c_str());
lite::ThreadAffinityCallback affinity_callback = [&](size_t thread_id) {
mgb::sys::set_cpu_affinity({core_ids[thread_id]});
};
......@@ -62,14 +62,14 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_cpu) {
mgb_log_warn("using cpu device\n");
mgb_log("using cpu device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::CPU;
};
}
#if LITE_WITH_CUDA
if (enable_cuda) {
mgb_log_warn("using cuda device\n");
mgb_log("using cuda device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
if (loc.type == mgb::CompNode::DeviceType::UNSPEC) {
loc.type = mgb::CompNode::DeviceType::CUDA;
......@@ -79,14 +79,14 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
}
#endif
if (enable_cpu_default) {
mgb_log_warn("using cpu default device\n");
mgb_log("using cpu default device\n");
model->get_mdl_config().comp_node_mapper = [](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::CPU;
loc.device = mgb::CompNode::Locator::DEVICE_CPU_DEFAULT;
};
}
if (enable_multithread) {
mgb_log_warn("using multithread device\n");
mgb_log("using multithread device\n");
model->get_mdl_config().comp_node_mapper =
[&](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
......@@ -95,7 +95,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
};
}
if (enable_multithread_default) {
mgb_log_warn("using multithread default device\n");
mgb_log("using multithread default device\n");
model->get_mdl_config().comp_node_mapper =
[&](mgb::CompNode::Locator& loc) {
loc.type = mgb::CompNode::DeviceType::MULTITHREAD;
......@@ -108,7 +108,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
for (auto id : core_ids) {
core_str += std::to_string(id) + ",";
}
mgb_log_warn("set multi thread core ids:%s\n", core_str.c_str());
mgb_log("set multi thread core ids:%s\n", core_str.c_str());
auto affinity_callback = [&](size_t thread_id) {
mgb::sys::set_cpu_affinity({core_ids[thread_id]});
};
......@@ -122,7 +122,7 @@ void XPUDeviceOption::config_model_internel<ModelMdl>(
}
} // namespace lar
XPUDeviceOption::XPUDeviceOption() {
void XPUDeviceOption::update() {
m_option_name = "xpu_device";
enable_cpu = FLAGS_cpu;
#if LITE_WITH_CUDA
......@@ -198,6 +198,7 @@ bool XPUDeviceOption::is_valid() {
std::shared_ptr<OptionBase> XPUDeviceOption::create_option() {
static std::shared_ptr<lar::XPUDeviceOption> option(new XPUDeviceOption);
if (XPUDeviceOption::is_valid()) {
option->update();
return std::static_pointer_cast<lar::OptionBase>(option);
} else {
return nullptr;
......
......@@ -24,8 +24,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
XPUDeviceOption();
XPUDeviceOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
bool enable_cpu;
......
......@@ -25,6 +25,7 @@ void COprLibOption::config_model_internel(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (!lib_path.empty()) {
mgb_log("load external C opr lib from %s\n", lib_path.c_str());
load_lib();
}
if (c_opr_args.is_run_c_opr_with_param) {
......@@ -176,7 +177,7 @@ void COprLibOption::set_Copr_IO(std::shared_ptr<ModelBase> model_ptr) {
config_extern_c_opr_dynamic_param(model->get_async_func(), c_opr_param);
}
COprLibOption::COprLibOption() {
void COprLibOption::update() {
m_option_name = "c_opr_lib";
lib_path = FLAGS_c_opr_lib;
c_opr_args.is_run_c_opr = !lib_path.empty();
......@@ -191,6 +192,7 @@ bool COprLibOption::is_valid() {
std::shared_ptr<OptionBase> COprLibOption::create_option() {
static std::shared_ptr<COprLibOption> option(new COprLibOption);
if (COprLibOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......
......@@ -32,8 +32,10 @@ public:
std::string option_name() const override { return m_option_name; };
void update() override;
private:
COprLibOption();
COprLibOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......
......@@ -25,10 +25,10 @@ void FastRunOption::config_model_internel<ModelLite>(
uint32_t strategy = 0;
#if MGB_ENABLE_FASTRUN
if (enable_full_run) {
LITE_WARN("enable full-run strategy for algo profile");
LITE_LOG("enable full-run strategy for algo profile");
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) | strategy;
} else if (enable_fast_run) {
LITE_WARN("enable fast-run strategy for algo profile");
LITE_LOG("enable fast-run strategy for algo profile");
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_PROFILE) |
static_cast<uint32_t>(Strategy::LITE_ALGO_OPTIMIZED) | strategy;
} else {
......@@ -38,7 +38,7 @@ void FastRunOption::config_model_internel<ModelLite>(
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_HEURISTIC) | strategy;
#endif
if (batch_binary_equal || enable_reproducible) {
LITE_WARN("enable reproducible strategy for algo profile");
LITE_LOG("enable reproducible strategy for algo profile");
if (batch_binary_equal)
strategy = static_cast<uint32_t>(Strategy::LITE_ALGO_REPRODUCIBLE) |
strategy;
......@@ -81,10 +81,10 @@ void FastRunOption::config_model_internel<ModelMdl>(
auto strategy = static_cast<Strategy>(0);
#if MGB_ENABLE_FASTRUN
if (enable_full_run) {
mgb_log_warn("enable full-run strategy for algo profile");
mgb_log("enable full-run strategy for algo profile");
strategy = Strategy::PROFILE | strategy;
} else if (enable_fast_run) {
mgb_log_warn("enable fast-run strategy for algo profile");
mgb_log("enable fast-run strategy for algo profile");
strategy = Strategy::PROFILE | Strategy::OPTIMIZED | strategy;
} else {
strategy = Strategy::HEURISTIC | strategy;
......@@ -93,20 +93,20 @@ void FastRunOption::config_model_internel<ModelMdl>(
strategy = Strategy::HEURISTIC | strategy;
#endif
if (batch_binary_equal || enable_reproducible) {
mgb_log_warn("enable reproducible strategy for algo profile");
mgb_log("enable reproducible strategy for algo profile");
strategy = Strategy::REPRODUCIBLE | strategy;
}
model->set_mdl_strategy(strategy);
//! set binary_equal_between_batch and shared_batch_size
if (batch_binary_equal) {
mgb_log_warn("enable batch binary equal");
mgb_log("enable batch binary equal");
model->get_mdl_config()
.comp_graph->options()
.fast_run_config.binary_equal_between_batch = true;
}
if (share_batch_size > 0) {
mgb_log_warn("set shared shared batch");
mgb_log("set shared shared batch");
model->get_mdl_config()
.comp_graph->options()
.fast_run_config.shared_batch_size = share_batch_size;
......@@ -145,7 +145,7 @@ void FastRunOption::config_model_internel<ModelMdl>(
using namespace lar;
bool FastRunOption::m_valid;
FastRunOption::FastRunOption() {
void FastRunOption::update() {
m_option_name = "fastrun";
#if MGB_ENABLE_FASTRUN
enable_fast_run = FLAGS_fast_run;
......@@ -207,6 +207,7 @@ bool FastRunOption::is_valid() {
std::shared_ptr<OptionBase> FastRunOption::create_option() {
static std::shared_ptr<FastRunOption> option(new FastRunOption);
if (FastRunOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -250,7 +251,7 @@ DEFINE_bool(
"https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/"
"index.html#reproducibility"
"for more details.");
DEFINE_uint32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
DEFINE_int32(fast_run_shared_batch_size, 0, "Set the batch size used during fastrun");
DEFINE_string(fast_run_algo_policy, "", "fast-run cache path.");
REGIST_OPTION_CREATOR(fastrun, lar::FastRunOption::create_option);
......
......@@ -10,7 +10,7 @@ DECLARE_bool(full_run);
#endif
DECLARE_bool(reproducible);
DECLARE_bool(binary_equal_between_batch);
DECLARE_uint32(fast_run_shared_batch_size);
DECLARE_int32(fast_run_shared_batch_size);
DECLARE_string(fast_run_algo_policy);
namespace lar {
......@@ -33,8 +33,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
FastRunOption();
FastRunOption() = default;
//! config template for different model
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
......
......@@ -93,11 +93,11 @@ void IOdumpOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (enable_io_dump) {
LITE_WARN("enable text io dump");
LITE_LOG("enable text io dump");
lite::Runtime::enable_io_txt_dump(model->get_lite_network(), dump_path);
}
if (enable_bin_io_dump) {
LITE_WARN("enable binary io dump");
LITE_LOG("enable binary io dump");
lite::Runtime::enable_io_bin_dump(model->get_lite_network(), dump_path);
}
//! FIXME: complete this once the corresponding API is added in lite
......@@ -108,7 +108,7 @@ void IOdumpOption::config_model_internel<ModelLite>(
LITE_THROW("lite model don't support the binary output dump");
}
if (enable_copy_to_host) {
LITE_WARN("lite model set copy to host defaultly");
LITE_LOG("lite model set copy to host defaultly");
}
}
}
......@@ -118,7 +118,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_io_dump) {
mgb_log_warn("enable text io dump");
mgb_log("enable text io dump");
auto iodump = std::make_unique<mgb::TextOprIODump>(
model->get_mdl_config().comp_graph.get(), dump_path.c_str());
iodump->print_addr(false);
......@@ -126,7 +126,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
if (enable_io_dump_stdout) {
mgb_log_warn("enable text io dump to stdout");
mgb_log("enable text io dump to stdout");
std::shared_ptr<FILE> std_out(stdout, [](FILE*) {});
auto iodump = std::make_unique<mgb::TextOprIODump>(
model->get_mdl_config().comp_graph.get(), std_out);
......@@ -135,7 +135,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
if (enable_io_dump_stderr) {
mgb_log_warn("enable text io dump to stderr");
mgb_log("enable text io dump to stderr");
std::shared_ptr<FILE> std_err(stderr, [](FILE*) {});
auto iodump = std::make_unique<mgb::TextOprIODump>(
model->get_mdl_config().comp_graph.get(), std_err);
......@@ -144,14 +144,14 @@ void IOdumpOption::config_model_internel<ModelMdl>(
}
if (enable_bin_io_dump) {
mgb_log_warn("enable binary io dump");
mgb_log("enable binary io dump");
auto iodump = std::make_unique<mgb::BinaryOprIODump>(
model->get_mdl_config().comp_graph.get(), dump_path);
io_dumper = std::move(iodump);
}
if (enable_bin_out_dump) {
mgb_log_warn("enable binary output dump");
mgb_log("enable binary output dump");
out_dumper = std::make_unique<OutputDumper>(dump_path.c_str());
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
......@@ -190,7 +190,7 @@ void IOdumpOption::config_model_internel<ModelMdl>(
////////////////////// Input options ////////////////////////
using namespace lar;
InputOption::InputOption() {
void InputOption::update() {
m_option_name = "input";
size_t start = 0;
auto end = FLAGS_input.find(";", start);
......@@ -204,9 +204,10 @@ InputOption::InputOption() {
}
std::shared_ptr<lar::OptionBase> lar::InputOption::create_option() {
static std::shared_ptr<InputOption> m_option(new InputOption);
static std::shared_ptr<InputOption> option(new InputOption);
if (InputOption::is_valid()) {
return std::static_pointer_cast<OptionBase>(m_option);
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
......@@ -219,7 +220,7 @@ void InputOption::config_model(
////////////////////// OprIOdump options ////////////////////////
IOdumpOption::IOdumpOption() {
void IOdumpOption::update() {
m_option_name = "iodump";
size_t valid_flag = 0;
if (!FLAGS_io_dump.empty()) {
......@@ -268,6 +269,7 @@ bool IOdumpOption::is_valid() {
std::shared_ptr<OptionBase> IOdumpOption::create_option() {
static std::shared_ptr<IOdumpOption> option(new IOdumpOption);
if (IOdumpOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......
......@@ -30,8 +30,10 @@ public:
//! interface implement from OptionBase
std::string option_name() const override { return m_option_name; };
void update() override;
private:
InputOption();
InputOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -50,8 +52,10 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; };
void update() override;
private:
IOdumpOption();
IOdumpOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......
......@@ -11,7 +11,7 @@ void LayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
#define ENABLE_LAYOUT(layout) \
LITE_WARN("enable " #layout " optimization"); \
LITE_LOG("enable " #layout " optimization"); \
model->get_config().options.enable_##layout = true; \
break;
......@@ -51,7 +51,7 @@ void lar::LayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
#define ENABLE_LAYOUT(layout) \
mgb_log_warn("enable " #layout " optimization"); \
mgb_log("enable " #layout " optimization"); \
model->get_mdl_config().comp_graph->options().graph_opt.enable_##layout(); \
break;
......@@ -91,7 +91,7 @@ void lar::LayoutOption::config_model_internel<ModelMdl>(
using namespace lar;
bool LayoutOption::m_valid;
LayoutOption::LayoutOption() {
void LayoutOption::update() {
m_option_name = "layout";
m_option_flag = static_cast<OptLayoutType>(0);
m_option = {
......@@ -157,6 +157,7 @@ bool LayoutOption::is_valid() {
std::shared_ptr<OptionBase> LayoutOption::create_option() {
static std::shared_ptr<LayoutOption> option(new LayoutOption);
if (LayoutOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -166,16 +167,20 @@ std::shared_ptr<OptionBase> LayoutOption::create_option() {
void LayoutOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
size_t valid_flag = 0;
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) {
if (FLAGS_enable_nchw4 ||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW4);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) {
if (FLAGS_enable_chwn4 ||
std::static_pointer_cast<lar::Bool>(m_option["enable_chwn4"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::CHWN4);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) {
if (FLAGS_enable_nchw44 ||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw44"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW44);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) {
if (FLAGS_enable_nchw88 ||
std::static_pointer_cast<lar::Bool>(m_option["enable_nchw88"])->get_value()) {
valid_flag |= static_cast<size_t>(OptLayoutType::NCHW88);
}
if (std::static_pointer_cast<lar::Bool>(m_option["enable_nchw32"])->get_value()) {
......
......@@ -37,9 +37,11 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
//! Constructor
LayoutOption();
LayoutOption() = default;
//! configuration for different model implement
template <typename ModelImpl>
......
......@@ -11,7 +11,7 @@ void GoptLayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (m_layout_transform) {
LITE_WARN("using global layout transform optimization\n");
LITE_LOG("using global layout transform optimization\n");
if (m_layout_transform_target ==
mgb::gopt::GraphTuningOptions::Target::CPU) {
model->get_config().device_type = LiteDeviceType::LITE_CPU;
......@@ -98,7 +98,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
}
} else if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (m_layout_transform) {
mgb_log_warn("using global layout transform optimization\n");
mgb_log("using global layout transform optimization\n");
auto&& load_result = model->get_mdl_load_result();
load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, m_layout_transform_target);
......@@ -150,7 +150,7 @@ void GoptLayoutOption::config_model_internel<ModelMdl>(
using namespace lar;
bool GoptLayoutOption::m_valid;
GoptLayoutOption::GoptLayoutOption() {
void GoptLayoutOption::update() {
m_option_name = "gopt_layout";
if (FLAGS_layout_transform != "cpu"
#if LITE_WITH_CUDA
......@@ -216,6 +216,7 @@ bool GoptLayoutOption::is_valid() {
std::shared_ptr<OptionBase> GoptLayoutOption::create_option() {
static std::shared_ptr<GoptLayoutOption> option(new GoptLayoutOption);
if (GoptLayoutOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......
......@@ -28,8 +28,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
GoptLayoutOption();
GoptLayoutOption() = default;
//! config template for different model
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
......
......@@ -24,7 +24,7 @@ void PackModelOption::config_model_internel(
using namespace lar;
////////////////////// PackModel options ////////////////////////
PackModelOption::PackModelOption() {
void PackModelOption::update() {
m_option_name = "pack_model";
if (!FLAGS_packed_model_dump.empty())
packed_model_dump = FLAGS_packed_model_dump;
......@@ -45,6 +45,7 @@ bool PackModelOption::is_valid() {
std::shared_ptr<OptionBase> PackModelOption::create_option() {
static std::shared_ptr<PackModelOption> option(new PackModelOption);
if (PackModelOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......
......@@ -19,8 +19,10 @@ public:
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; }
void update() override;
private:
PackModelOption();
PackModelOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);
......
......@@ -15,7 +15,7 @@ void FusePreprocessOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_fuse_preprocess) {
LITE_WARN("enable fuse-preprocess optimization");
LITE_LOG("enable fuse-preprocess optimization");
model->get_config().options.fuse_preprocess = true;
}
}
......@@ -27,7 +27,7 @@ void FusePreprocessOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_fuse_preprocess) {
mgb_log_warn("enable fuse-preprocess optimization");
mgb_log("enable fuse-preprocess optimization");
graph_option.graph_opt.enable_fuse_preprocess();
}
}
......@@ -35,7 +35,7 @@ void FusePreprocessOption::config_model_internel<ModelMdl>(
} // namespace lar
using namespace lar;
bool FusePreprocessOption::m_valid;
FusePreprocessOption::FusePreprocessOption() {
void FusePreprocessOption::update() {
m_option_name = "fuse_preprocess";
enable_fuse_preprocess = FLAGS_enable_fuse_preprocess;
m_option = {{"enable_fuse_preprocess", lar::Bool::make(false)}};
......@@ -51,6 +51,7 @@ bool FusePreprocessOption::is_valid() {
std::shared_ptr<OptionBase> FusePreprocessOption::create_option() {
static std::shared_ptr<FusePreprocessOption> option(new FusePreprocessOption);
if (FusePreprocessOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -73,7 +74,7 @@ void WeightPreprocessOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (weight_preprocess) {
LITE_WARN("enable weight-preprocess optimization");
LITE_LOG("enable weight-preprocess optimization");
model->get_config().options.weight_preprocess = true;
}
}
......@@ -85,14 +86,14 @@ void WeightPreprocessOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (weight_preprocess) {
mgb_log_warn("enable weight-preprocess optimization");
mgb_log("enable weight-preprocess optimization");
graph_option.graph_opt.enable_weight_preprocess();
}
}
}
} // namespace lar
WeightPreprocessOption::WeightPreprocessOption() {
void WeightPreprocessOption::update() {
m_option_name = "weight_preprocess";
weight_preprocess = FLAGS_weight_preprocess;
m_option = {{"weight_preprocess", lar::Bool::make(false)}};
......@@ -108,6 +109,7 @@ bool WeightPreprocessOption::is_valid() {
std::shared_ptr<OptionBase> WeightPreprocessOption::create_option() {
static std::shared_ptr<WeightPreprocessOption> option(new WeightPreprocessOption);
if (WeightPreprocessOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -142,14 +144,14 @@ void FuseConvBiasNonlinearOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_fuse_conv_bias_nonlinearity) {
mgb_log_warn("enable fuse conv+bias+nonlinearity optimization");
mgb_log("enable fuse conv+bias+nonlinearity optimization");
graph_option.graph_opt.enable_fuse_conv_bias_nonlinearity();
}
}
}
} // namespace lar
FuseConvBiasNonlinearOption::FuseConvBiasNonlinearOption() {
void FuseConvBiasNonlinearOption::update() {
m_option_name = "fuse_conv_bias_nonlinearity";
enable_fuse_conv_bias_nonlinearity = FLAGS_enable_fuse_conv_bias_nonlinearity;
m_option = {{"enable_fuse_conv_bias_nonlinearity", lar::Bool::make(false)}};
......@@ -166,6 +168,7 @@ std::shared_ptr<OptionBase> FuseConvBiasNonlinearOption::create_option() {
static std::shared_ptr<FuseConvBiasNonlinearOption> option(
new FuseConvBiasNonlinearOption);
if (FuseConvBiasNonlinearOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -203,14 +206,14 @@ void FuseConvBiasElemwiseAddOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_fuse_conv_bias_with_z) {
mgb_log_warn("enable fuse conv+bias+z optimization");
mgb_log("enable fuse conv+bias+z optimization");
graph_option.graph_opt.enable_fuse_conv_bias_with_z();
}
}
}
} // namespace lar
FuseConvBiasElemwiseAddOption::FuseConvBiasElemwiseAddOption() {
void FuseConvBiasElemwiseAddOption::update() {
m_option_name = "fuse_conv_bias_with_z";
enable_fuse_conv_bias_with_z = FLAGS_enable_fuse_conv_bias_with_z;
m_option = {{"enable_fuse_conv_bias_with_z", lar::Bool::make(false)}};
......@@ -227,6 +230,7 @@ std::shared_ptr<OptionBase> FuseConvBiasElemwiseAddOption::create_option() {
static std::shared_ptr<FuseConvBiasElemwiseAddOption> option(
new FuseConvBiasElemwiseAddOption);
if (FuseConvBiasElemwiseAddOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -250,26 +254,26 @@ void GraphRecordOption::config_model_internel<ModelLite>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& config_option = model->get_config().options;
if (const_shape) {
LITE_WARN("enable const var shape");
LITE_LOG("enable const var shape");
config_option.const_shape = true;
}
if (fake_first) {
LITE_WARN("enable fake-first optimization");
LITE_LOG("enable fake-first optimization");
config_option.fake_next_exec = true;
}
if (no_sanity_check) {
LITE_WARN("disable var sanity check optimization");
LITE_LOG("disable var sanity check optimization");
config_option.var_sanity_check_first_run = false;
}
if (m_record_comp_seq == 1) {
LITE_WARN("set record_comp_seq_level to 1");
LITE_LOG("set record_comp_seq_level to 1");
}
if (m_record_comp_seq == 2) {
mgb_assert(
no_sanity_check,
"--no-sanity-check should be set before "
"--record-comp-seq2");
LITE_WARN("set record_comp_seq_level to 2");
LITE_LOG("set record_comp_seq_level to 2");
}
config_option.comp_node_seq_record_level = m_record_comp_seq;
}
......@@ -281,33 +285,33 @@ void GraphRecordOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (const_shape) {
mgb_log_warn("enable const var shape");
mgb_log("enable const var shape");
model->get_mdl_config().const_var_shape = true;
}
if (fake_first) {
mgb_log_warn("enable fake-first optimization");
mgb_log("enable fake-first optimization");
graph_option.fake_next_exec = true;
}
if (no_sanity_check) {
mgb_log_warn("disable var sanity check optimization");
mgb_log("disable var sanity check optimization");
graph_option.var_sanity_check_first_run = false;
}
if (m_record_comp_seq == 1) {
mgb_log_warn("set record_comp_seq_level to 1");
mgb_log("set record_comp_seq_level to 1");
}
if (m_record_comp_seq == 2) {
mgb_assert(
no_sanity_check && !fake_first,
"--no-sanity-check should be set before "
"--record-comp-seq2 and --fake-first should not be set");
mgb_log_warn("set record_comp_seq_level to 2");
mgb_log("set record_comp_seq_level to 2");
}
graph_option.comp_node_seq_record_level = m_record_comp_seq;
}
}
} // namespace lar
GraphRecordOption::GraphRecordOption() {
void GraphRecordOption::update() {
m_option_name = "graph_record";
m_record_comp_seq = 0;
const_shape = FLAGS_const_shape;
......@@ -350,6 +354,7 @@ bool GraphRecordOption::is_valid() {
std::shared_ptr<OptionBase> GraphRecordOption::create_option() {
static std::shared_ptr<GraphRecordOption> option(new GraphRecordOption);
if (GraphRecordOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -387,7 +392,7 @@ void MemoryOptimizeOption::config_model_internel<ModelLite>(
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (workspace_limit != SIZE_MAX) {
LITE_WARN("set workspace limit to %ld", workspace_limit);
LITE_LOG("set workspace limit to %ld", workspace_limit);
lite::Runtime::set_network_algo_workspace_limit(
model->get_lite_network(), workspace_limit);
}
......@@ -400,12 +405,12 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (disable_mem_opt) {
mgb_log_warn("disable memory optimization");
mgb_log("disable memory optimization");
graph_option.seq_opt.enable_mem_plan_opt = false;
graph_option.seq_opt.enable_mem_reuse_alloc = false;
}
if (workspace_limit < SIZE_MAX) {
mgb_log_warn("set workspace limit to %ld", workspace_limit);
mgb_log("set workspace limit to %ld", workspace_limit);
auto&& output_spec = model->get_output_spec();
mgb::SymbolVarArray vars;
for (auto i : output_spec) {
......@@ -417,7 +422,7 @@ void MemoryOptimizeOption::config_model_internel<ModelMdl>(
}
} // namespace lar
MemoryOptimizeOption::MemoryOptimizeOption() {
void MemoryOptimizeOption::update() {
m_option_name = "memory_optimize";
disable_mem_opt = FLAGS_disable_mem_opt;
workspace_limit = FLAGS_workspace_limit;
......@@ -432,6 +437,7 @@ bool MemoryOptimizeOption::is_valid() {
std::shared_ptr<OptionBase> MemoryOptimizeOption::create_option() {
static std::shared_ptr<MemoryOptimizeOption> option(new MemoryOptimizeOption);
if (MemoryOptimizeOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -451,7 +457,7 @@ void JITOption::config_model_internel<ModelLite>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& config_option = model->get_config().options;
if (enable_jit) {
LITE_WARN("enable JIT (level 1)");
LITE_LOG("enable JIT (level 1)");
config_option.jit_level = 1;
}
}
......@@ -463,13 +469,13 @@ void JITOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_jit) {
mgb_log_warn("enable JIT (level 1)");
mgb_log("enable JIT (level 1)");
graph_option.graph_opt.jit = 1;
}
}
}
} // namespace lar
JITOption::JITOption() {
void JITOption::update() {
m_option_name = "JIT";
enable_jit = FLAGS_enable_jit;
}
......@@ -482,6 +488,7 @@ bool JITOption::is_valid() {
std::shared_ptr<OptionBase> JITOption::create_option() {
static std::shared_ptr<JITOption> option(new JITOption);
if (JITOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -500,12 +507,12 @@ void TensorRTOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (!tensorrt_cache.empty()) {
LITE_WARN("set tensorrt cache as %s", tensorrt_cache.c_str());
LITE_LOG("set tensorrt cache as %s", tensorrt_cache.c_str());
lite::set_tensor_rt_cache(tensorrt_cache);
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (enable_tensorrt) {
LITE_WARN("enable TensorRT");
LITE_LOG("enable TensorRT");
lite::Runtime::use_tensorrt(model->get_lite_network());
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
......@@ -521,11 +528,11 @@ void TensorRTOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (enable_tensorrt) {
mgb_log_warn("using tensorRT");
mgb_log("using tensorRT");
graph_option.graph_opt.tensorrt = true;
}
if (!tensorrt_cache.empty()) {
mgb_log_warn("use tensorrt cache: %s", tensorrt_cache.c_str());
mgb_log("use tensorrt cache: %s", tensorrt_cache.c_str());
mgb::TensorRTEngineCache::enable_engine_cache(true);
mgb::TensorRTEngineCache::set_impl(
std::make_shared<mgb::TensorRTEngineCacheIO>(
......@@ -541,7 +548,7 @@ void TensorRTOption::config_model_internel<ModelMdl>(
}
} // namespace lar
TensorRTOption::TensorRTOption() {
void TensorRTOption::update() {
m_option_name = "tensorRT";
enable_tensorrt = FLAGS_tensorrt;
tensorrt_cache = FLAGS_tensorrt_cache;
......@@ -556,6 +563,7 @@ bool TensorRTOption::is_valid() {
std::shared_ptr<OptionBase> TensorRTOption::create_option() {
static std::shared_ptr<TensorRTOption> option(new TensorRTOption);
if (TensorRTOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......
......@@ -39,8 +39,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
FusePreprocessOption();
FusePreprocessOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -65,8 +67,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
WeightPreprocessOption();
WeightPreprocessOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -91,8 +95,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
FuseConvBiasNonlinearOption();
FuseConvBiasNonlinearOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -117,8 +123,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
FuseConvBiasElemwiseAddOption();
FuseConvBiasElemwiseAddOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
std::string m_option_name;
......@@ -143,8 +151,10 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
GraphRecordOption();
GraphRecordOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -169,8 +179,10 @@ public:
std::string option_name() const override { return m_option_name; };
void update() override;
private:
MemoryOptimizeOption();
MemoryOptimizeOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -191,8 +203,10 @@ public:
std::string option_name() const override { return m_option_name; };
void update() override;
private:
JITOption();
JITOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......@@ -212,8 +226,10 @@ public:
std::string option_name() const override { return m_option_name; };
void update() override;
private:
TensorRTOption();
TensorRTOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
......
......@@ -28,6 +28,10 @@ public:
//! get option map
virtual OptionValMap* get_option() { return nullptr; }
//! update option value
virtual void update(){};
virtual ~OptionBase() = default;
};
......
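The pattern repeated across the options above: the constructors become default (they no longer read gflags), and create_option() calls update() on the shared instance, so flag values set by the new unit tests are re-read every time an option is created. A rough sketch with a hypothetical DemoOption (the real options follow this shape):
class DemoOption final : public OptionBase {
public:
    static bool is_valid();                       // gate on the related FLAGS_*
    static std::shared_ptr<OptionBase> create_option() {
        static std::shared_ptr<DemoOption> option(new DemoOption);
        if (DemoOption::is_valid()) {
            option->update();                     // re-read gflags here, not in the ctor
            return std::static_pointer_cast<OptionBase>(option);
        }
        return nullptr;
    }
    void update() override { m_option_name = "demo"; /* copy FLAGS_* into members */ }
    std::string option_name() const override { return m_option_name; }
    void config_model(RuntimeParam&, std::shared_ptr<ModelBase>) override {}
private:
    DemoOption() = default;
    std::string m_option_name;
};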
......@@ -22,10 +22,10 @@ void PluginOption::config_model_internel<ModelLite>(
else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (!profile_path.empty()) {
if (!enable_profile_host) {
LITE_WARN("enable profiling");
LITE_LOG("enable profiling");
model->get_lite_network()->enable_profile_performance(profile_path);
} else {
LITE_WARN("enable profiling for host");
LITE_LOG("enable profiling for host");
model->get_lite_network()->enable_profile_performance(profile_path);
}
}
......@@ -39,18 +39,18 @@ void PluginOption::config_model_internel<ModelMdl>(
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& config = model->get_mdl_config();
if (range > 0) {
mgb_log_warn("enable number range check");
mgb_log("enable number range check");
model->set_num_range_checker(float(range));
}
if (enable_check_dispatch) {
mgb_log_warn("enable cpu dispatch check");
mgb_log("enable cpu dispatch check");
cpu_dispatch_checker =
std::make_unique<mgb::CPUDispatchChecker>(config.comp_graph.get());
}
if (!var_value_check_str.empty()) {
mgb_log_warn("enable variable value check");
mgb_log("enable variable value check");
size_t init_idx = 0, switch_interval;
auto sep = var_value_check_str.find(':');
if (sep != std::string::npos) {
......@@ -67,9 +67,9 @@ void PluginOption::config_model_internel<ModelMdl>(
if (!profile_path.empty()) {
if (!enable_profile_host) {
mgb_log_warn("enable profiling");
mgb_log("enable profiling");
} else {
mgb_log_warn("enable profiling for host");
mgb_log("enable profiling for host");
}
model->set_profiler();
}
......@@ -79,12 +79,11 @@ void PluginOption::config_model_internel<ModelMdl>(
else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
#if MGB_ENABLE_JSON
if (!profile_path.empty()) {
mgb_log_warn("filename %s", profile_path.c_str());
if (model->get_profiler()) {
model->get_profiler()
->to_json_full(model->get_async_func().get())
->writeto_fpath(profile_path);
mgb_log_warn("profiling result written to %s", profile_path.c_str());
mgb_log("profiling result written to %s", profile_path.c_str());
}
}
#endif
......@@ -94,7 +93,7 @@ void PluginOption::config_model_internel<ModelMdl>(
} // namespace lar
using namespace lar;
PluginOption::PluginOption() {
void PluginOption::update() {
m_option_name = "plugin";
range = FLAGS_range;
enable_check_dispatch = FLAGS_check_dispatch;
......@@ -125,6 +124,7 @@ bool PluginOption::is_valid() {
std::shared_ptr<OptionBase> PluginOption::create_option() {
static std::shared_ptr<PluginOption> option(new PluginOption);
if (PluginOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......@@ -199,7 +199,7 @@ void DebugOption::format_and_print(
std::stringstream ss;
ss << table;
printf("%s\n\n", ss.str().c_str());
LITE_LOG("%s\n\n", ss.str().c_str());
}
template <>
......@@ -243,7 +243,7 @@ void DebugOption::format_and_print(
std::stringstream ss;
ss << table;
printf("%s\n\n", ss.str().c_str());
mgb_log("%s\n\n", ss.str().c_str());
}
template <>
......@@ -260,7 +260,7 @@ void DebugOption::config_model_internel<ModelLite>(
#endif
#endif
if (enable_verbose) {
LITE_WARN("enable verbose");
LITE_LOG("enable verbose");
lite::set_log_level(LiteLogLevel::DEBUG);
}
......@@ -272,7 +272,7 @@ void DebugOption::config_model_internel<ModelLite>(
#endif
} else if (runtime_param.stage == RunStage::AFTER_MODEL_LOAD) {
if (enable_display_model_info) {
LITE_WARN("enable display model information");
LITE_LOG("enable display model information");
format_and_print<ModelLite>("Runtime Model Info", model);
}
} else if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
......@@ -287,7 +287,7 @@ void DebugOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (enable_verbose) {
mgb_log_warn("enable verbose");
mgb_log("enable verbose");
mgb::set_log_level(mgb::LogLevel::DEBUG);
}
......@@ -299,21 +299,21 @@ void DebugOption::config_model_internel<ModelMdl>(
#endif
} else if (runtime_param.stage == RunStage::BEFORE_OUTSPEC_SET) {
if (enable_display_model_info) {
mgb_log_warn("enable display model information");
mgb_log("enable display model information");
format_and_print<ModelMdl>("Runtime Model Info", model);
}
} else if (runtime_param.stage == RunStage::AFTER_OUTSPEC_SET) {
#ifndef __IN_TEE_ENV__
#if MGB_ENABLE_JSON
if (!static_mem_log_dir_path.empty()) {
mgb_log_warn("enable get static memeory information");
mgb_log("enable get static memeory information");
model->get_async_func()->get_static_memory_alloc_info(
static_mem_log_dir_path);
}
#endif
#endif
if (disable_assert_throw) {
mgb_log_warn("disable assert throw");
mgb_log("disable assert throw");
auto on_opr = [](mgb::cg::OperatorNodeBase* opr) {
if (opr->same_type<mgb::opr::AssertEqual>()) {
opr->cast_final<mgb::opr::AssertEqual>().disable_throw_on_error();
......@@ -333,7 +333,7 @@ void DebugOption::config_model_internel<ModelMdl>(
} // namespace lar
DebugOption::DebugOption() {
void DebugOption::update() {
m_option_name = "debug";
enable_display_model_info = FLAGS_model_info;
enable_verbose = FLAGS_verbose;
......@@ -367,6 +367,7 @@ bool DebugOption::is_valid() {
std::shared_ptr<OptionBase> DebugOption::create_option() {
static std::shared_ptr<DebugOption> option(new DebugOption);
if (DebugOption::is_valid()) {
option->update();
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
......
......@@ -44,8 +44,10 @@ public:
std::string option_name() const override { return m_option_name; };
void update() override;
private:
PluginOption();
PluginOption() = default;
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
double range;
......@@ -74,8 +76,10 @@ public:
std::string option_name() const override { return m_option_name; };
void update() override;
private:
DebugOption();
DebugOption() = default;
template <typename ModelImpl>
void format_and_print(const std::string&, std::shared_ptr<ModelImpl>){};
template <typename ModelImpl>
......
......@@ -5,7 +5,7 @@ using namespace lar;
DECLARE_bool(c_opr_lib_with_param);
DECLARE_bool(fitting);
StrategyOption::StrategyOption() {
void StrategyOption::update() {
m_option_name = "run_strategy";
warmup_iter = FLAGS_fitting ? 3 : FLAGS_warmup_iter;
run_iter = FLAGS_fitting ? 10 : FLAGS_iter;
......@@ -20,6 +20,7 @@ StrategyOption::StrategyOption() {
std::shared_ptr<OptionBase> StrategyOption::create_option() {
static std::shared_ptr<StrategyOption> option(new StrategyOption);
option->update();
return std::static_pointer_cast<OptionBase>(option);
}
......@@ -43,12 +44,13 @@ void StrategyOption::config_model(
}
}
TestcaseOption::TestcaseOption() {
void TestcaseOption::update() {
m_option_name = "run_testcase";
}
std::shared_ptr<OptionBase> TestcaseOption::create_option() {
static std::shared_ptr<TestcaseOption> option(new TestcaseOption);
option->update();
return std::static_pointer_cast<OptionBase>(option);
}
......
......@@ -25,9 +25,11 @@ public:
OptionValMap* get_option() override { return &m_option; }
void update() override;
private:
//! Constructor
StrategyOption();
StrategyOption() = default;
//! configuration for different model implement
std::string m_option_name;
......@@ -52,9 +54,11 @@ public:
//! get option name
std::string option_name() const override { return m_option_name; };
void update() override;
private:
//! Constructor
TestcaseOption();
TestcaseOption() = default;
//! configuration for different model implement
std::string m_option_name;
......
......@@ -233,8 +233,8 @@ void OptionsTimeProfiler::profile_with_given_options(
"the log) when profile option:\n%s\n",
option_code.c_str());
} else {
printf("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(),
average);
mgb_log("profile option:\n%s\naverage time = %.2f\n", option_code.c_str(),
average);
//! record profile result
m_options_profile_result.insert({option_code, average});
......@@ -370,7 +370,6 @@ void UserInfoParser::parse_info(std::shared_ptr<OptionsFastManager>& manager) {
FittingStrategy::FittingStrategy(std::string model_path) {
m_manager = std::make_shared<OptionsFastManager>();
m_dumped_model = FLAGS_dump_fitting_model;
mgb::set_log_level(mgb::LogLevel::INFO);
m_options = std::make_shared<OptionMap>();
m_model_path = model_path;
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map();
......@@ -518,10 +517,10 @@ void FittingStrategy::AutoCleanFile::dump_model() {
void FittingStrategy::run() {
auto mgb_version = mgb::get_version();
auto dnn_version = megdnn::get_version();
printf("megbrain/lite/load_and_run:\nusing MegBrain "
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev,
dnn_version.major, dnn_version.minor, dnn_version.patch);
mgb_log("megbrain/lite/load_and_run:\nusing MegBrain "
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
mgb_version.major, mgb_version.minor, mgb_version.patch, mgb_version.is_dev,
dnn_version.major, dnn_version.minor, dnn_version.patch);
// ! create profiler with given user info
m_info_parser.get_user_info();
m_info_parser.parse_info(m_manager);
......
......@@ -5,13 +5,10 @@
#include "megbrain/utils/timer.h"
#include "megbrain/version.h"
#include "megdnn/version.h"
#include "misc.h"
using namespace lar;
NormalStrategy::NormalStrategy(std::string model_path) {
mgb::set_log_level(mgb::LogLevel::WARN);
lite::set_log_level(LiteLogLevel::WARN);
m_options = std::make_shared<OptionMap>();
m_model_path = model_path;
auto option_creator_map = OptionFactory::get_Instance().get_option_creator_map();
......@@ -47,7 +44,7 @@ void NormalStrategy::run_subline() {
mgb::RealTimer timer;
model->load_model();
printf("load model: %.3fms\n", timer.get_msecs_reset());
mgb_log("load model: %.3fms\n", timer.get_msecs_reset());
//! after load configure
auto config_after_load = [&]() {
......@@ -62,10 +59,10 @@ void NormalStrategy::run_subline() {
auto warm_up = [&]() {
auto warmup_num = m_runtime_param.warmup_iter;
for (size_t i = 0; i < warmup_num; i++) {
printf("=== prepare: %.3fms; going to warmup\n\n", timer.get_msecs_reset());
mgb_log("=== prepare: %.3fms; going to warmup", timer.get_msecs_reset());
model->run_model();
model->wait();
printf("warm up %lu %.3fms\n", i, timer.get_msecs_reset());
mgb_log("warm up %lu %.3fms", i, timer.get_msecs_reset());
m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
stage_config_model();
}
......@@ -83,21 +80,21 @@ void NormalStrategy::run_subline() {
auto cur = timer.get_msecs();
m_runtime_param.stage = RunStage::AFTER_RUNNING_WAIT;
stage_config_model();
printf("iter %lu/%lu: e2e=%.3f ms (host=%.3f ms)\n", i, run_num, cur,
exec_time);
mgb_log("iter %lu/%lu: e2e=%.3f ms (host=%.3f ms)", i, run_num, cur,
exec_time);
time_sum += cur;
time_sqrsum += cur * cur;
fflush(stdout);
min_time = std::min(min_time, cur);
max_time = std::max(max_time, cur);
}
printf("\n=== finished test #%u: time=%.3f ms avg_time=%.3f ms "
"standard_deviation=%.3f ms min=%.3f ms max=%.3f ms\n\n",
idx, time_sum, time_sum / run_num,
std::sqrt(
(time_sqrsum * run_num - time_sum * time_sum) /
(run_num * (run_num - 1))),
min_time, max_time);
mgb_log("=== finished test #%u: time=%.3f ms avg_time=%.3f ms "
"standard_deviation=%.3f ms min=%.3f ms max=%.3f ms",
idx, time_sum, time_sum / run_num,
std::sqrt(
(time_sqrsum * run_num - time_sum * time_sum) /
(run_num * (run_num - 1))),
min_time, max_time);
return time_sum;
};
......@@ -122,7 +119,7 @@ void NormalStrategy::run_subline() {
stage_config_model();
}
printf("=== total time: %.3fms\n", tot_time);
mgb_log("=== total time: %.3fms\n", tot_time);
//! execute after run
m_runtime_param.stage = RunStage::AFTER_MODEL_RUNNING;
stage_config_model();
......@@ -131,9 +128,9 @@ void NormalStrategy::run_subline() {
void NormalStrategy::run() {
auto v0 = mgb::get_version();
auto v1 = megdnn::get_version();
printf("megbrain/lite/load_and_run:\nusing MegBrain "
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch);
mgb_log("megbrain/lite/load_and_run:\nusing MegBrain "
"%d.%d.%d(%d) and MegDNN %d.%d.%d\n",
v0.major, v0.minor, v0.patch, v0.is_dev, v1.major, v1.minor, v1.patch);
size_t thread_num = m_runtime_param.threads;
auto run_sub = [&]() { run_subline(); };
......
......@@ -73,7 +73,7 @@ public:
#define LITE_LOG_(level, msg...) (void)0
#endif
#define LITE_LOG(fmt...) LITE_LOG_(DEBUG, fmt);
#define LITE_LOG(fmt...) LITE_LOG_(INFO, fmt);
#define LITE_DEBUG(fmt...) LITE_LOG_(DEBUG, fmt);
#define LITE_WARN(fmt...) LITE_LOG_(WARN, fmt);
#define LITE_ERROR(fmt...) LITE_LOG_(ERROR, fmt);
......
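With this change LITE_LOG emits at INFO instead of DEBUG, which is why the option messages above moved from LITE_WARN/mgb_log_warn to LITE_LOG/mgb_log: they remain visible under the default INFO level set in main(), but are no longer reported as warnings. A minimal usage sketch, assuming lite's misc.h (which defines these macros) is on the include path:
#include "misc.h"  // LITE_LOG / LITE_DEBUG / LITE_WARN
void report_example() {
    LITE_DEBUG("only visible when the log level is DEBUG");
    LITE_LOG("visible at the default INFO level");  // was emitted at DEBUG before this change
    LITE_WARN("reserved for real warnings");
}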
if(MGE_WITH_TEST)
include_directories(PUBLIC
$<BUILD_INTERFACE:${PROJECT_SOURCE_DIR}/lite/load_and_run/src>)
file(GLOB_RECURSE SOURCES ./*.cpp main.cpp)
add_executable(lite_test ${SOURCES})
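# lar_object provides the load_and_run option/strategy code exercised by the new lar tests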
target_link_libraries(lite_test lar_object)
target_link_libraries(lite_test gtest)
target_link_libraries(lite_test lite_static)
if(LITE_BUILD_WITH_MGE)
# lite_test depends on the megbrain interface
target_link_libraries(lite_test megbrain)
if(MGE_WITH_ROCM)
# FIXME: hip obj can not find cpp obj only through lite_static
message(WARNING "MGE_WITH_ROCM is valid link to megdnn")
target_link_libraries(lite_test megdnn)
endif()
endif()
......
#include <gtest/gtest.h>
#include <string.h>
#include <memory>
#include "test_options.h"
using namespace lar;
DECLARE_bool(lite);
DECLARE_bool(cpu);
#if LITE_WITH_CUDA
DECLARE_bool(cuda);
#endif
DECLARE_bool(enable_nchw4);
DECLARE_bool(enable_chwn4);
DECLARE_bool(enable_nchw44);
DECLARE_bool(enable_nchw88);
DECLARE_bool(enable_nchw32);
DECLARE_bool(enable_nchw64);
DECLARE_bool(enable_nhwcd4);
DECLARE_bool(enable_nchw44_dot);
namespace {
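// each wrapper sets its FLAGS_* to true on construction and restores it to false on destruction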
BOOL_OPTION_WRAP(enable_nchw4);
BOOL_OPTION_WRAP(enable_chwn4);
BOOL_OPTION_WRAP(enable_nchw44);
BOOL_OPTION_WRAP(enable_nchw88);
BOOL_OPTION_WRAP(enable_nchw32);
BOOL_OPTION_WRAP(enable_nchw64);
BOOL_OPTION_WRAP(enable_nhwcd4);
BOOL_OPTION_WRAP(enable_nchw44_dot);
BOOL_OPTION_WRAP(lite);
BOOL_OPTION_WRAP(cpu);
#if LITE_WITH_CUDA
BOOL_OPTION_WRAP(cuda);
#endif
} // anonymous namespace
TEST(TestLarLayout, X86_CPU) {
DEFINE_WRAP(cpu);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(enable_nchw4);
TEST_BOOL_OPTION(enable_chwn4);
TEST_BOOL_OPTION(enable_nchw44);
TEST_BOOL_OPTION(enable_nchw44_dot);
TEST_BOOL_OPTION(enable_nchw64);
TEST_BOOL_OPTION(enable_nchw32);
TEST_BOOL_OPTION(enable_nchw88);
}
TEST(TestLarLayout, X86_CPU_LITE) {
DEFINE_WRAP(cpu);
DEFINE_WRAP(lite);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(enable_nchw4);
TEST_BOOL_OPTION(enable_nchw44);
TEST_BOOL_OPTION(enable_nchw44_dot);
TEST_BOOL_OPTION(enable_nchw64);
TEST_BOOL_OPTION(enable_nchw32);
TEST_BOOL_OPTION(enable_nchw88);
}
#if LITE_WITH_CUDA
TEST(TestLarLayout, CUDA) {
DEFINE_WRAP(cuda);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(enable_nchw4);
TEST_BOOL_OPTION(enable_chwn4);
TEST_BOOL_OPTION(enable_nchw64);
TEST_BOOL_OPTION(enable_nchw32);
FLAGS_cuda = false;
}
TEST(TestLarLayout, CUDA_LITE) {
DEFINE_WRAP(cuda);
DEFINE_WRAP(lite);
std::string model_path = "./shufflenet.mge";
TEST_BOOL_OPTION(enable_nchw4);
TEST_BOOL_OPTION(enable_nchw64);
TEST_BOOL_OPTION(enable_nchw32);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include "test_options.h"
using namespace lar;
void lar::run_NormalStrategy(std::string model_path) {
auto origin_level = mgb::get_log_level();
mgb::set_log_level(mgb::LogLevel::WARN);
NormalStrategy strategy(model_path);
strategy.run();
mgb::set_log_level(origin_level);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#pragma once
#include <iostream>
#include <thread>
#include "../load_and_run/src/strategys/strategy.h"
#include "../load_and_run/src/strategys/strategy_normal.h"
#include "megbrain/common.h"
#include "megbrain/utils/timer.h"
#include "megbrain/version.h"
#include "megdnn/version.h"
#include "misc.h"
namespace lar {
//! run load_and_run NormalStrategy to test different options
void run_NormalStrategy(std::string model_path);
} // namespace lar
#define BOOL_OPTION_WRAP(option) \
struct BoolOptionWrap_##option { \
BoolOptionWrap_##option() { FLAGS_##option = true; } \
~BoolOptionWrap_##option() { FLAGS_##option = false; } \
};
#define DEFINE_WRAP(option) BoolOptionWrap_##option flags_##option;
#define TEST_BOOL_OPTION(option) \
{ \
BoolOptionWrap_##option flags_##option; \
run_NormalStrategy(model_path); \
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
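For reference, TEST_BOOL_OPTION(enable_nchw44) expands to a scoped toggle of the corresponding gflag plus one run of the normal strategy, roughly:
{
    // BoolOptionWrap_enable_nchw44 sets FLAGS_enable_nchw44 = true on construction
    // and restores it to false on destruction
    BoolOptionWrap_enable_nchw44 flags_enable_nchw44;
    run_NormalStrategy(model_path);
}  // flag restored here, so the next case starts from the default flag values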