diff --git a/src/gopt/impl/layout_transform_pass.cpp b/src/gopt/impl/layout_transform_pass.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d994cdcc3ffc00087c2bc5478f98d131639e9ac5
--- /dev/null
+++ b/src/gopt/impl/layout_transform_pass.cpp
@@ -0,0 +1,169 @@
+/**
+ * \file src/gopt/impl/layout_transform_pass.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./opr_format_modifier.h"
+#include "./utils.h"
+#include "megbrain/gopt/global_layout_transform.h"
+#include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/opr/imgproc.h"
+#include "megbrain/serialization/sereg.h"
+
+using namespace mgb;
+using namespace gopt;
+using namespace cg;
+
+/* =================== LayoutTransformPass ======================*/
+void LayoutTransformPass::apply(OptState& opt) const {
+    opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
+                                   VarReplaceCheckFlag::CHECK_SHAPE);
+    SubGraphExtractor extractor(m_ctx->opr_list());
+    auto partitions = extractor.extract(opt.graph().endpoint_vars());
+
+    using Solution = SolverBase::Solution;
+    Solution solution;
+    ThinHashSet<VarNode*> endpoint_vars;
+    for (auto&& partition : partitions) {
+        if (solution.empty()) {
+            solution = m_solver->solve(Problem(partition, *m_ctx));
+        } else {
+            auto s = m_solver->solve(Problem(partition, *m_ctx));
+            for (auto&& kv : s)
+                solution.insert({kv.first, kv.second});
+        }
+        for (auto&& o : partition.output()) {
+            endpoint_vars.insert(o);
+        }
+    }
+
+    auto&& opr_configs = m_ctx->opr_configs();
+    auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
+    auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
+    ThinHashMap<VarNode*, TensorFormats> var2fmts;
+    static ThinHashSet<Typeinfo*> format_aware_oprs = {
+#define cb(_Opr) opr::_Opr::typeinfo(),
+            FOREACH_FORMAT_AWARE_OPR(cb)
+#undef cb
+    };
+    auto rewriter = opt.graph().make_rewriter();
+    auto on_opr = [this, &opr_configs, &base_fmt, &reformat_attribute,
+                   &rewriter, &solution, &var2fmts,
+                   &endpoint_vars](OperatorNodeBase* opr) {
+        auto it = solution.find(opr);
+        if (it != solution.end()) {
+            auto opr_fmt = it->second;
+            auto find = opr_configs.find(opr->dyn_typeinfo());
+            Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
+            if (find != opr_configs.end()) {
+                fmtcfg = (*find->second.at(opr_fmt))(opr);
+            }
+            VarNodeArray new_inp;
+            size_t nr_inps = opr->input().size();
+            TensorFormats out_fmt;
+            if (fmtcfg.valid()) {
+                nr_inps = std::min(fmtcfg.val().input_tensor_formats.size(),
+                                   nr_inps);
+                out_fmt = fmtcfg.val().output_tensor_formats[0];
+            } else {
+                out_fmt = opr_format_to_tensor_formats(opr_fmt);
+            }
+            new_inp.resize(nr_inps);
+            for (size_t i = 0; i < nr_inps; ++i) {
+                auto&& var = opr->input(i);
+                auto&& new_var = rewriter.get_var(var);
+                auto find = var2fmts.find(new_var);
+                TensorFormats from;
+                if (find == var2fmts.end()) {
+                    from = base_fmt;
+                } else {
+                    from = find->second;
+                }
+                auto to = fmtcfg.valid()
+                                  ? fmtcfg.val().input_tensor_formats[i]
+                                  : opr_format_to_tensor_formats(opr_fmt);
+                bool is_parameter =
+                        fmtcfg.valid() && fmtcfg.val().input_tensor_types[i] ==
+                                                  TensorType::WEIGHT;
+                ReformatManager::ReformatImpl reformat;
+                ReformatManager::ReformatKey key{from, to, reformat_attribute,
+                                                 var->dtype().enumv(),
+                                                 var->dtype().enumv()};
+                if (is_parameter) {
+                    auto aligned_desc = make_aligned_desc(base_fmt, out_fmt);
+                    reformat = ReformatManager::instance()
+                                       .auto_aligned_reformat_weight(
+                                               var, key, aligned_desc);
+                } else {
+                    reformat = ReformatManager::instance()
+                                       .auto_aligned_reformat_featrue(
+                                               var, base_fmt, key);
+                }
+                if (from != to && !new_var->shape().is_scalar())
+                    new_var = reformat({new_var});
+                new_inp[i] = new_var;
+            }
+            VarNode* new_out;
+            if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
+                new_out = intl::modify_opr_format(opr_fmt, new_inp, opr);
+            } else {
+                new_out = serialization::copy_opr_shallow(*opr, new_inp,
+                                                          opr->config())
+                                  ->output(0);
+            }
+            if (endpoint_vars.count(opr->output(0)) && out_fmt != base_fmt) {
+                ReformatManager::ReformatKey key{
+                        out_fmt, base_fmt, reformat_attribute,
+                        opr->output(0)->dtype().enumv(),
+                        opr->output(0)->dtype().enumv()};
+                auto reformat = ReformatManager::instance()
+                                        .auto_aligned_reformat_featrue(
+                                                opr->output(0), base_fmt, key);
+                new_out = reformat({new_out});
+                var2fmts[new_out] = base_fmt;
+            } else {
+                var2fmts[new_out] = out_fmt;
+            }
+            auto &&out0 = opr->output(),
+                 &&out1 = new_out->owner_opr()->output();
+            mgb_assert(opr->usable_output().size() ==
+                               new_out->owner_opr()->usable_output().size(),
+                       "bad opr replace: src=%s{%s} dst=%s{%s}, "
+                       "src.size=%zu "
+                       "dst.size=%zu",
+                       opr->cname(), opr->dyn_typeinfo()->name,
+                       new_out->owner_opr()->cname(),
+                       new_out->owner_opr()->dyn_typeinfo()->name, out0.size(),
+                       out1.size());
+            for (size_t i = 0; i < out0.size(); ++i) {
+                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
+                    mgb_assert(!out1[i]->contain_flag(
+                            VarNode::Flag::VOLATILE_CONTENT));
+                    auto src = out0[i];
+                    auto dst = out1[i];
+                    rewriter.replace_var(
+                            src, dst,
+                            mgb_cstr_log(ssprintf("replace opr(%s) to new opr "
+                                                  "format(%s)",
+                                                  opr->cname(),
+                                                  opr_format_to_string(opr_fmt))
+                                                 .c_str()));
+                }
+            }
+        } else {
+            auto new_opr = rewriter.auto_replace_outputs(opr);
+            var2fmts[new_opr->output(0)] = base_fmt;
+        }
+    };
+    opt.graph().iter(on_opr);
+    rewriter.apply_inplace();
+}
+
+// vim: syntax=cpp.doxygen
diff --git a/src/gopt/test/layout_transform_pass.cpp b/src/gopt/test/layout_transform_pass.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..575761b8b30bb20212e1ecb559f5a90e67c7dd74
--- /dev/null
+++ b/src/gopt/test/layout_transform_pass.cpp
@@ -0,0 +1,303 @@
+/**
+ * \file src/gopt/test/layout_transform_pass.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "./helper.h"
+#include "megbrain/gopt/global_layout_transform.h"
+#include "megbrain/gopt/inference.h"
+#include "megbrain/opr/dnn/pooling.h"
+#include "megbrain/opr/imgproc.h"
+#include "megbrain/opr/nn_int.h"
+#include "megbrain/plugin/profiler.h"
+#include "megbrain/serialization/serializer.h"
+
+using namespace mgb;
+using namespace gopt;
+using namespace serialization;
+
+#if MGB_CUDA
+TEST(TestLayoutTransform, Feature) {
+    auto inp_file = InputFile::make_fs("./feat.mdl");
+
+    auto format = GraphLoader::identify_graph_dump_format(*inp_file);
+    ASSERT_TRUE(format.valid());
+    auto loader = GraphLoader::make(std::move(inp_file), format.val());
+
+    GraphLoader::LoadConfig load_config;
+    load_config.comp_graph = ComputingGraph::make();
+    auto&& graph_opt = load_config.comp_graph->options();
+    graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
+    graph_opt.graph_opt.enable_fuse_conv_bias_with_z();
+    auto ret = loader->load(load_config, false);
+
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = S::PROFILE;
+    gopt::modify_opr_algo_strategy_inplace({ret.output_var_list}, strategy);
+
+    using OprFormat = LayoutTransformContext::OprFormat;
+    using OprList = LayoutTransformContext::OprList;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
+    using Attribute = LayoutTransformContext::Attribute;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(),
+            opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(),
+            opr::TypeCvt::typeinfo(),
+            opr::PoolingForward::typeinfo(),
+            opr::WarpPerspectiveForward::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHWc4, TensorFormats::NCHWc32,
+            TensorFormats::CHWNc4};
+    Attribute attribute = {OprFormat::NCHW4, TensorFormats::NCHWc4,
+                           ReformatAttribute::DEFAULT};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats),
+            attribute);
+    ctx->add_opr_config(opr::ConvBiasForward::typeinfo(),
+                        {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4})
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(),
+                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4})
+            .add_opr_config(opr::WarpPerspectiveForward::typeinfo(),
+                            OprFormat::NCHW4);
+    auto profiler = ProfilerBase::make_profiler();
+    auto filter = [](const GraphPartition& partition) {
+        auto has_nchw4_conv = false;
+        for (auto&& opr : partition.all_oprs()) {
+            if (opr->dyn_typeinfo() == opr::ConvBiasForward::typeinfo()) {
+                auto& conv = opr->cast_final_safe<opr::ConvBiasForward>();
+                if (conv.param().format ==
+                    LayoutTransformContext::OprFormat::NCHW4) {
+                    has_nchw4_conv = true;
+                    break;
+                }
+            }
+        }
+        return has_nchw4_conv;
+    };
+    std::unique_ptr<SolverBase> solver{new DynamicProgrammingSolver(
+            std::move(profiler), std::move(filter))};
+    auto new_out_vars = gopt::GraphOptimizer{}
+                                .add_pass<FuseConvBiasNonlinPass>()
+                                .add_pass<FuseConvBiasZPass>()
+                                .add_pass<LayoutTransformPass>(
+                                        std::move(ctx), std::move(solver))
+                                .add_pass<ShuffleShuffleRemovePass>()
+                                .add_pass(FuseNCHW4Int8Preprocess::make())
+                                .add_pass<FoldingConvBiasDimshufflePass>()
+                                .add_pass<ParamFusePass>()
+                                .add_pass<ParamMergePass>()
+                                .apply(ret.output_var_list)
+                                .endpoint_vars();
+    auto dumper = GraphDumper::make(OutputFile::make_fs("model_opt.mgb"));
+    dumper->dump({new_out_vars});
+}
+
+TEST(TestLayoutTransform, Detection) {
+    auto inp_file = InputFile::make_fs("./det.mdl");
+    static const char* magic = "mgbteset0";
+    size_t skip_size = sizeof(magic) + sizeof(uint32_t);
+    char skip[skip_size];
+    inp_file->read(skip, skip_size);
+
+    auto format = GraphLoader::identify_graph_dump_format(*inp_file);
+    ASSERT_TRUE(format.valid());
+    auto loader = GraphLoader::make(std::move(inp_file), format.val());
+
+    GraphLoader::LoadConfig load_config;
+    load_config.comp_graph = ComputingGraph::make();
+    auto&& graph_opt = load_config.comp_graph->options();
+    graph_opt.graph_opt.enable_fuse_conv_bias_nonlinearity();
+    graph_opt.graph_opt.enable_fuse_conv_bias_with_z();
+    auto ret = loader->load(load_config, false);
+
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = S::PROFILE;
+    gopt::modify_opr_algo_strategy_inplace({ret.output_var_list}, strategy);
+
+    using OprFormat = LayoutTransformContext::OprFormat;
+    using OprList = LayoutTransformContext::OprList;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
+    using Attribute = LayoutTransformContext::Attribute;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(),
+            opr::ConvolutionForward::typeinfo(),
+            opr::ConvolutionBackwardData::typeinfo(),
+            opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(),
+            opr::TypeCvt::typeinfo(),
+            opr::PoolingForward::typeinfo(),
+            opr::WarpPerspectiveForward::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW,    TensorFormats::NHWC,
+            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
+            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
+                           ReformatAttribute::DEFAULT};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats),
+            attribute);
+    ctx->add_opr_config(
+               opr::ConvBiasForward::typeinfo(),
+               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
+                OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
+            .add_opr_config(opr::ConvolutionForward::typeinfo(),
+                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
+                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(),
+                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
+                     OprFormat::NCHW64, OprFormat::CHWN4})
+            .add_opr_config(
+                    opr::WarpPerspectiveForward::typeinfo(),
+                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+
+    auto profiler = ProfilerBase::make_profiler();
+    std::unique_ptr<SolverBase> solver{
+            new DynamicProgrammingSolver(std::move(profiler))};
+    auto new_out_vars = gopt::GraphOptimizer{}
+                                .add_pass<LayoutTransformPass>(
+                                        std::move(ctx), std::move(solver))
+                                .add_pass<ShuffleShuffleRemovePass>()
+                                .add_pass(FuseNCHW4Int8Preprocess::make())
+                                .add_pass<FoldingConvBiasDimshufflePass>()
+                                .add_pass<ParamFusePass>()
+                                .add_pass<ParamMergePass>()
+                                .apply(ret.output_var_list)
+                                .endpoint_vars();
+    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
+    std::vector<OutputSpecItem> outs(new_out_vars.size());
+    for (size_t i = 0; i < new_out_vars.size(); ++i) {
+        auto cb = [](DeviceTensorND& /* d */) {};
+        outs[i] = std::make_pair(new_out_vars[i], cb);
+    }
+    GraphProfiler gprof{load_config.comp_graph.get()};
+    auto func = load_config.comp_graph->compile(outs);
+    for (size_t i = 0; i < 10; ++i)
+        func->execute();
+    func->wait();
+    gprof.to_json_full(func.get())->writeto_fpath(output_file("det.json"));
+}
+
+TEST(TestLayoutTransform, DetectionHead) {
+    REQUIRE_GPU(1);
+    auto cn = CompNode::load("gpu0");
+    cn.activate();
+    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);
+
+    constexpr size_t N = 16, C = 3, H = 768, W = 1280;
+    HostTensorGenerator<dtype::Uint8> gen;
+
+    auto graph = ComputingGraph::make();
+    auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W}, cn));
+    auto data = opr::TypeCvt::make(h2d, dtype::Float32());
+    auto sub_128 = data + (-128);
+    auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f));
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
+                        .rename(name),
+                dtype);
+    };
+    auto w = mkcvar("w", {16, 3, 3, 3}, dtype::QuantizedS8(1.f));
+    auto b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
+    opr::ConvBias::Param param;
+    param.format = opr::ConvBias::Param::Format::NCHW;
+    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
+    param.stride_h = param.stride_w = 2;
+    param.pad_h = param.pad_w = 1;
+    auto conv_1 = opr::ConvBias::make(
+            x, w, b, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
+    conv_1 = opr::TypeCvt::make(
+            conv_1, dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8)));
+    auto w1 = mkcvar("w1", {16, 16, 3, 3}, dtype::QuantizedS4(1.f));
+    auto b1 = mkcvar("b1", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
+    auto y = opr::ConvBias::make(conv_1, w1, b1, param, {},
+                                 OperatorNodeConfig(dtype::Quantized4Asymm(
+                                         1.f, static_cast<uint8_t>(8))));
+
+    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+    S strategy = S::PROFILE;
+    gopt::modify_opr_algo_strategy_inplace({y}, strategy);
+
+    using OprFormat = LayoutTransformContext::OprFormat;
+    using OprList = LayoutTransformContext::OprList;
+    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
+    using Attribute = LayoutTransformContext::Attribute;
+    OprList opr_list = {
+            opr::ConvBiasForward::typeinfo(),
+            opr::ConvolutionForward::typeinfo(),
+            opr::ConvolutionBackwardData::typeinfo(),
+            opr::ElemwiseMultiType::typeinfo(),
+            opr::Elemwise::typeinfo(),
+            opr::TypeCvt::typeinfo(),
+            opr::PoolingForward::typeinfo(),
+            opr::WarpPerspectiveForward::typeinfo(),
+    };
+    SmallVector<TensorFormats> available_tensor_formats = {
+            TensorFormats::NCHW,    TensorFormats::NHWC,
+            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
+            TensorFormats::NCHWc64, TensorFormats::CHWNc4};
+    Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
+                           ReformatAttribute::DEFAULT};
+    auto ctx = std::make_unique<LayoutTransformContext>(
+            std::move(opr_list), std::move(available_tensor_formats),
+            attribute);
+    ctx->add_opr_config(
+               opr::ConvBiasForward::typeinfo(),
+               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
+                OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
+            .add_opr_config(opr::ConvolutionForward::typeinfo(),
+                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
+                            {OprFormat::NCHW, OprFormat::NCHW4})
+            .add_opr_config(
+                    opr::PoolingForward::typeinfo(),
+                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
+                     OprFormat::NCHW64, OprFormat::CHWN4})
+            .add_opr_config(
+                    opr::WarpPerspectiveForward::typeinfo(),
+                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
+
+    auto profiler = ProfilerBase::make_profiler();
+    std::unique_ptr<SolverBase> solver{
+            new DynamicProgrammingSolver(std::move(profiler))};
+    auto new_out_vars = gopt::GraphOptimizer{}
+                                .add_pass<LayoutTransformPass>(
+                                        std::move(ctx), std::move(solver))
+                                .add_pass<ShuffleShuffleRemovePass>()
+                                .add_pass(FuseNCHW4Int8Preprocess::make())
+                                .add_pass<FoldingConvBiasDimshufflePass>()
+                                .add_pass<ParamFusePass>()
+                                .add_pass<ParamMergePass>()
+                                .apply(SymbolVarArray{y})
+                                .endpoint_vars();
+    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
+    std::vector<OutputSpecItem> outs(new_out_vars.size());
+    for (size_t i = 0; i < new_out_vars.size(); ++i) {
+        auto cb = [](DeviceTensorND& /* d */) {};
+        outs[i] = std::make_pair(new_out_vars[i], cb);
+    }
+    GraphProfiler gprof{graph.get()};
+    auto func = graph->compile(outs);
+    for (size_t i = 0; i < 10; ++i)
+        func->execute();
+    func->wait();
+    gprof.to_json_full(func.get())->writeto_fpath(output_file("det_head.json"));
+}
+
+#endif
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}