提交 2d54ad18 编写于 作者: M Megvii Engine Team

feat(lite): add global layout transform interface for load and run

GitOrigin-RevId: 65c2430ec2c93a48d633d3b638bac885e81ce52e
上级 ba2f0c2e
......@@ -138,7 +138,7 @@ void ConvBiasForwardImpl::AlgoCUDNNConv::exec(const ExecArgs& args) const {
if (args.z_layout->ndim > 0) {
auto z_tensor = *args.z_tensor;
if (args.z_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
z_tensor.raw_ptr = bundle.get(2);
z_tensor = TensorND{bundle.get(2), args.z_tensor->layout};
z_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
......
......@@ -36,6 +36,8 @@ enum class RunStage {
AFTER_RUNNING_ITER = 6,
AFTER_MODEL_RUNNING = 7,
GLOBAL_OPTIMIZATION = 8,
};
/*!
* \brief: type of different model
......
......@@ -52,15 +52,15 @@ void ModelMdl::load_model() {
m_model_file->read(&testcase_num, sizeof(testcase_num));
}
auto format =
m_format =
mgb::serialization::GraphLoader::identify_graph_dump_format(*m_model_file);
mgb_assert(
format.valid(),
m_format.valid(),
"invalid format, please make sure model is dumped by GraphDumper");
//! load computing graph of model
m_loader = mgb::serialization::GraphLoader::make(
std::move(m_model_file), format.val());
std::move(m_model_file), m_format.val());
m_load_result = m_loader->load(m_load_config, false);
m_load_config.comp_graph.reset();
......@@ -87,9 +87,15 @@ void ModelMdl::make_output_spec() {
m_asyc_exec = m_load_result.graph_compile(m_output_spec);
}
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader() {
m_loader = mgb::serialization::GraphLoader::make(
m_loader->reset_file(), m_loader->format());
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(
std::unique_ptr<mgb::serialization::InputFile> input_file) {
if (input_file) {
m_loader = mgb::serialization::GraphLoader::make(
std::move(input_file), m_loader->format());
} else {
m_loader = mgb::serialization::GraphLoader::make(
m_loader->reset_file(), m_loader->format());
}
return m_loader;
}
......
......@@ -50,8 +50,16 @@ public:
//! get load config for megDL model
mgb::serialization::GraphLoadConfig& get_mdl_config() { return m_load_config; }
//! reset the graph loader for dump_with_testcase model
std::shared_ptr<mgb::serialization::GraphLoader>& reset_loader();
/*! reset the underlying graph loader from which further load() would read()
*
* \param input_file new input_file, can be null
* \return new loader
*/
std::shared_ptr<mgb::serialization::GraphLoader>& reset_loader(
std::unique_ptr<mgb::serialization::InputFile> input_file = {});
//! get the underlying graph loader
std::shared_ptr<mgb::serialization::GraphLoader>& get_loader() { return m_loader; }
//! algo strategy for runing model
void set_mdl_strategy(Strategy& u_strategy) { m_strategy = u_strategy; }
......@@ -88,11 +96,18 @@ public:
m_load_config.comp_graph.get(), range);
}
std::unique_ptr<mgb::serialization::GraphDumper> get_dumper(
std::unique_ptr<mgb::serialization::OutputFile> out_file) {
return mgb::serialization::GraphDumper::make(
std::move(out_file), m_format.val());
}
private:
bool share_model_mem;
std::string model_path;
std::unique_ptr<mgb::serialization::InputFile> m_model_file;
mgb::serialization::GraphLoadConfig m_load_config;
mgb::Maybe<mgb::serialization::GraphDumpFormat> m_format;
mgb::serialization::GraphLoader::LoadResult m_load_result;
std::shared_ptr<mgb::serialization::GraphLoader> m_loader;
......
/**
* \file lite/load_and_run/src/options/layout_trans_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#include "layout_trans_options.h"
#include <gflags/gflags.h>
#include "megbrain/serialization/serializer.h"
#include "misc.h"
#include "models/model_lite.h"
#include "models/model_mdl.h"
namespace lar {
template <>
void GoptLayoutOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite> /* model */) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
LITE_THROW("lite model don't support global graph optimization");
}
}
template <>
void GoptLayoutOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::GLOBAL_OPTIMIZATION) {
if (layout_transform) {
auto&& load_result = model->get_mdl_load_result();
load_result.output_var_list = mgb::gopt::layout_transform(
load_result.output_var_list, layout_transform_target);
if (!layout_transform_dump_file.empty()) {
auto out_file = mgb::serialization::OutputFile::make_fs(
layout_transform_dump_file.c_str(), 'w');
auto testcase_num = model->get_testcase_num();
if (testcase_num) {
const char* magic = "mgbtest0";
constexpr size_t len = sizeof(magic);
out_file->write(magic, len);
out_file->write(&testcase_num, sizeof(testcase_num));
}
using DumpConfig = mgb::serialization::GraphDumper::DumpConfig;
DumpConfig config{1, false, false};
auto dumper = model->get_dumper(std::move(out_file));
dumper->dump(load_result.output_var_list, config);
if (testcase_num) {
auto input_file = model->get_loader()->reset_file();
auto current_offset = input_file->tell();
auto loader = model->reset_loader(std::move(input_file));
auto testcase = loader->load(model->get_mdl_config(), false);
mgb::serialization::GraphDumper::DumpConfig config{1, false, false};
for (size_t i = 0; i < testcase_num; ++i) {
auto casefile = mgb::serialization::OutputFile::make_fs(
layout_transform_dump_file.c_str(), 'a');
auto casedumper = model->get_dumper(std::move(casefile));
casedumper->dump(testcase.output_var_list, config);
if (i != testcase_num - 1) {
loader = model->reset_loader();
testcase = loader->load(model->get_mdl_config(), false);
}
}
input_file = model->get_loader()->reset_file();
input_file->rewind();
input_file->skip(current_offset);
model->reset_loader(std::move(input_file));
}
}
}
}
}
} // namespace lar
using namespace lar;
GoptLayoutOption::GoptLayoutOption() {
m_option_name = "gopt_layout";
if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
FLAGS_layout_transform != "opencl") {
layout_transform = false;
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::UNSPEC;
} else {
layout_transform = true;
if (FLAGS_layout_transform == "cuda") {
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CUDA;
} else if (FLAGS_layout_transform == "cpu") {
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::CPU;
} else if (FLAGS_layout_transform == "opencl") {
layout_transform_target = mgb::gopt::GraphTuningOptions::Target::OPENCL;
}
}
layout_transform_dump_file = FLAGS_layout_transform_dump;
}
bool GoptLayoutOption::is_valid() {
bool ret = false;
if (!FLAGS_layout_transform.empty()) {
if (FLAGS_layout_transform != "cuda" && FLAGS_layout_transform != "cpu" &&
FLAGS_layout_transform != "opencl") {
mgb_assert(
false,
"unsupported target(got:%s) for global layout "
"transform",
FLAGS_layout_transform.c_str());
ret = false;
} else {
ret = true;
}
}
ret = ret || FLAGS_layout_transform_dump.empty();
return ret;
}
std::shared_ptr<OptionBase> GoptLayoutOption::create_option() {
static std::shared_ptr<GoptLayoutOption> option(new GoptLayoutOption);
if (GoptLayoutOption::is_valid()) {
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
}
void GoptLayoutOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
CONFIG_MODEL_FUN;
}
DEFINE_string(
layout_transform, "",
"Enable global layout transform optimization for computing graph. User should "
"specify the device target for the optimization, and a series of passes will "
"be applied on the computing graph. The passes will benchmark the elapsed time "
"of operators on different tensor layouts, and select fastest implementation "
"for the operators. The optimization process will take some time. The default "
"target is unspec, which all the available for operators will be profiled. So "
"the optimize time will be longer.");
DEFINE_string(
layout_transform_dump, "",
"The computing graph after global layout transform will be dumped to the given "
"file path.");
REGIST_OPTION_CREATOR(gopt_layout, lar::GoptLayoutOption::create_option);
/**
* \file lite/load_and_run/src/options/layout_trans_options.h
*
* This file is part of MegEngine, a deep learning framework developed by
* Megvii.
*
* \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
*/
#pragma once
#include <gflags/gflags.h>
#include "megbrain/gopt/inference.h"
#include "models/model.h"
#include "option_base.h"
DECLARE_string(layout_transform);
DECLARE_string(layout_transform_dump);
namespace lar {
class GoptLayoutOption final : public OptionBase {
public:
//! get condition for construct FastRunOption
static bool is_valid();
//! creat option using condition from cmdline args
static std::shared_ptr<OptionBase> create_option();
//! configure model for different runtime_param
void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
//! get options name for quickly search
std::string option_name() const override { return m_option_name; }
private:
GoptLayoutOption();
//! config template for different model
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>) {}
bool layout_transform;
std::string m_option_name;
std::string layout_transform_dump_file;
mgb::gopt::GraphTuningOptions::Target layout_transform_target;
};
} // namespace lar
......@@ -93,4 +93,4 @@ DEFINE_bool(share_param_mem, false, "load model from shared memeory");
REGIST_OPTION_CREATOR(run_strategy, lar::StrategyOption::create_option);
REGIST_OPTION_CREATOR(run_testcase, lar::TestcaseOption::create_option);
\ No newline at end of file
REGIST_OPTION_CREATOR(run_testcase, lar::TestcaseOption::create_option);
......@@ -60,6 +60,9 @@ void NormalStrategy::run_subline() {
m_runtime_param.stage = RunStage::AFTER_MODEL_LOAD;
stage_config_model();
m_runtime_param.stage = RunStage::GLOBAL_OPTIMIZATION;
stage_config_model();
m_runtime_param.stage = RunStage::BEFORE_OUTSPEC_SET;
stage_config_model();
......@@ -164,4 +167,4 @@ void NormalStrategy::run() {
mgb_assert(false, "--thread must input a positive number!!");
}
//! execute before run
}
\ No newline at end of file
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册