Commit fac67e7c authored by Megvii Engine Team

feat(gopt): support nchw44 global pooling with fuse_grain

GitOrigin-RevId: 4c43a149f8214aae7f48ae72bbd6928c1e18f0f5
Parent 8461c8d8
#include "src/arm_common/adaptive_pooling/opr_impl.h"
#include "src/common/opr_delegate.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace arm_common {
void AdaptivePoolingImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
auto adapt_fwd = [=]() {
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src.layout, dst.layout);
opr->exec(src, dst, workspace);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(adapt_fwd());
return;
}
size_t AdaptivePoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src, dst);
auto need_size = opr->get_workspace_in_bytes(src, dst);
return need_size;
}
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/oprs.h"
namespace megdnn {
namespace arm_common {
class AdaptivePoolingImpl final : public AdaptivePoolingForward {
public:
using AdaptivePoolingForward::AdaptivePoolingForward;
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) override;
};
} // namespace arm_common
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -2,6 +2,7 @@
#include "src/arm_common/handle.h"
#include "src/arm_common/adaptive_pooling/opr_impl.h"
#include "src/arm_common/conv_bias/opr_impl.h"
#include "src/arm_common/convolution/opr_impl.h"
#include "src/arm_common/cvt_color/opr_impl.h"
......@@ -45,6 +46,7 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(ConvolutionBackwardData)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(RNNCell)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTMCell)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(LSTM)
MEGDNN_SPECIALIZE_CREATE_OPERATOR(AdaptivePooling)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wpragmas"
......
......@@ -2,6 +2,7 @@
#include "src/arm_common/pooling/algo.h"
#include "src/common/algo_chooser.h"
#include "src/common/metahelper.h"
#include "src/common/opr_delegate.h"
using namespace megdnn;
using namespace arm_common;
......@@ -48,10 +49,72 @@ public:
};
PoolingImpl::AlgoPack PoolingImpl::sm_algo_pack;
namespace {
//! merge the (assumed contiguous) H and W axes into a single axis so that a
//! Reduce along axis 2 implements global pooling
TensorLayout merge_hw_layout(TensorLayout src) {
src.ndim -= 1;
src.shape[2] = src.shape[2] * src.shape[3];
src.stride[2] = src.stride[3];
for (size_t i = 3; i < src.ndim; ++i) {
src.shape[i] = src.shape[i + 1];
src.stride[i] = src.stride[i + 1];
}
return src;
}
std::pair<TensorND, TensorND> get_global_pooling_reduce_tensor(
const TensorND& src, const TensorND& dst) {
auto reduce_src_layout = merge_hw_layout(src.layout);
auto reduce_dst_layout = merge_hw_layout(dst.layout);
return std::make_pair<TensorND, TensorND>(
{src.raw_ptr(), reduce_src_layout}, {dst.raw_ptr(), reduce_dst_layout});
}
std::unique_ptr<Reduce> get_global_pooling_reduce_opr(
Handle* handle, const PoolingImpl::PoolingKernSizeParam& param) {
std::unique_ptr<Reduce> opr;
if (handle) {
opr = handle->create_operator<Reduce>();
} else {
opr = inplace_cpu_handle()->create_operator<Reduce>();
}
param::Reduce reduce_param;
reduce_param.axis = 2;
if (param.mode == PoolingImpl::Param::Mode::MAX) {
reduce_param.mode = param::Reduce::Mode::MAX;
} else {
megdnn_assert(
param.mode == PoolingImpl::Param::Mode::AVERAGE ||
param.mode == PoolingImpl::Param::Mode::AVERAGE_COUNT_EXCLUDE_PADDING);
reduce_param.mode = param::Reduce::Mode::MEAN;
}
opr->param() = reduce_param;
return opr;
}
//! true iff the pooling is a global pooling (window covers the whole input,
//! no padding, 1x1 output) that can be lowered to a Reduce over merged H*W
bool is_global_pooling_reduce(PoolingImpl::PoolingKernSizeParam& param) {
bool fmt_ok = param.format == PoolingImpl::Param::Format::NCHW ||
param.format == PoolingImpl::Param::Format::NCHW44 ||
param.format == PoolingImpl::Param::Format::NCHW88;
bool size_ok = param.filter[0] == param.isz[0] && param.filter[1] == param.isz[1] &&
param.padding[0] == 0 && param.padding[1] == 0 &&
param.osz[0] == 1 && param.osz[1] == 1;
bool dtype_ok = param.src_type == param.dst_type &&
param.src_type.enumv() != DTypeEnum::Int8;
return fmt_ok && size_ok && dtype_ok;
}
} // namespace
size_t PoolingImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto param = make_pooling_kern_szie_param(this, src, dst);
bool fwd_reduce = is_global_pooling_reduce(param);
if (fwd_reduce) {
TensorND src_tensor{nullptr, src};
TensorND dst_tensor{nullptr, dst};
auto reduce_tensor = get_global_pooling_reduce_tensor(src_tensor, dst_tensor);
auto&& opr = get_global_pooling_reduce_opr(nullptr, param);
auto reduce_need = opr->get_workspace_in_bytes(
reduce_tensor.first.layout, reduce_tensor.second.layout);
return reduce_need;
}
auto algo = get_algorithm(this, src, dst);
if (!is_fallback_algo(algo)) {
size_t arm_common_workspace = 0;
......@@ -93,6 +156,18 @@ void PoolingImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, workspace.size);
auto param = make_pooling_kern_param(this, src, dst, workspace);
bool fwd_reduce = is_global_pooling_reduce(param);
if (fwd_reduce) {
auto global_pooling_fwd = [=]() {
auto reduce_tensor = get_global_pooling_reduce_tensor(src, dst);
auto&& opr = get_global_pooling_reduce_opr(nullptr, param);
opr->exec(reduce_tensor.first, reduce_tensor.second, workspace);
};
MEGDNN_DISPATCH_CPU_KERN_OPR(global_pooling_fwd());
return;
}
auto algo = get_algorithm(this, src.layout, dst.layout);
if (!is_fallback_algo(algo)) {
algo->exec(param);
......
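For orientation, below is a minimal standalone sketch (plain arrays and printf, not megdnn::TensorLayout) of what merge_hw_layout above computes for a contiguous NCHW44 tensor, and why a Reduce along axis 2 afterwards amounts to global pooling; the concrete shape (1, 2, 3, 5, 4) is only an illustration.
#include <cstddef>
#include <cstdio>
int main() {
    // contiguous NCHW44 tensor: N=1, C/4=2, H=3, W=5, pack=4 (strides in elements)
    size_t shape[5] = {1, 2, 3, 5, 4};
    size_t stride[5] = {120, 60, 20, 4, 1};
    // merge H and W into a single axis, mirroring merge_hw_layout
    size_t ndim = 5 - 1;
    shape[2] = shape[2] * shape[3];      // 3 * 5 = 15
    stride[2] = stride[3];               // merged-axis stride = old W stride
    for (size_t i = 3; i < ndim; ++i) {  // shift the trailing pack axis forward
        shape[i] = shape[i + 1];
        stride[i] = stride[i + 1];
    }
    // layout is now (1, 2, 15, 4) with strides (120, 60, 4, 1); reducing axis 2
    // with MAX or MEAN gives (1, 2, 1, 4), i.e. global pooling over H*W
    for (size_t i = 0; i < ndim; ++i)
        std::printf("axis %zu: shape=%zu stride=%zu\n", i, shape[i], stride[i]);
    return 0;
}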
......@@ -8,7 +8,9 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param(
const TensorLayout& src, const TensorLayout& dst) {
auto param_format = param().format;
size_t IH, IW, OH, OW;
if (param_format == param::AdaptivePooling::Format::NCHW) {
if (param_format == param::AdaptivePooling::Format::NCHW ||
param_format == param::AdaptivePooling::Format::NCHW44 ||
param_format == param::AdaptivePooling::Format::NCHW88) {
IH = src.shape[2];
IW = src.shape[3];
OH = dst.shape[2];
......@@ -19,7 +21,8 @@ param::Pooling AdaptivePoolingBase::deduce_pooling_param(
OH = dst.shape[1];
OW = dst.shape[2];
} else {
megdnn_throw("AdaptivePooling only support NCHW or NHWC format");
megdnn_throw(
"AdaptivePooling only supports NCHW, NHWC, NCHW44 or NCHW88 format");
}
param::Pooling ret;
......
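The body that actually fills param::Pooling is elided in the hunk above ("......"). As a hedged sketch only (not the elided code), the conventional adaptive-pooling derivation of stride and window from (IH, IW) and (OH, OW) looks like the following; the struct and function names here are hypothetical stand-ins.
#include <cstddef>
#include <cstdio>
struct PoolingParamSketch {
    size_t pad_h = 0, pad_w = 0;
    size_t stride_h = 1, stride_w = 1;
    size_t window_h = 1, window_w = 1;
};
// assumed convention: stride = floor(I / O), window covers the remainder
PoolingParamSketch deduce_sketch(size_t IH, size_t IW, size_t OH, size_t OW) {
    PoolingParamSketch p;
    p.stride_h = IH / OH;
    p.stride_w = IW / OW;
    p.window_h = IH - (OH - 1) * p.stride_h;
    p.window_w = IW - (OW - 1) * p.stride_w;
    return p;
}
int main() {
    // global pooling case: OH = OW = 1 -> window == input size, zero padding
    auto p = deduce_sketch(25, 25, 1, 1);
    std::printf(
            "window=%zux%zu stride=%zux%zu\n", p.window_h, p.window_w, p.stride_h,
            p.stride_w);
    return 0;
}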
......@@ -140,7 +140,9 @@ void PoolingForward::check_exec(
const TensorLayout& src, const TensorLayout& dst, size_t workspace_in_bytes) {
check_layout_fwd(src, dst);
auto required_workspace_in_bytes = get_workspace_in_bytes(src, dst);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert(
workspace_in_bytes >= required_workspace_in_bytes, "need %zu, got %zu",
required_workspace_in_bytes, workspace_in_bytes);
}
void PoolingBackward::check_exec(
......
......@@ -6,11 +6,17 @@
namespace megdnn {
namespace naive {
size_t AdaptivePoolingForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst) {
auto opr = inplace_cpu_handle(2)->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src, dst);
auto need_size = opr->get_workspace_in_bytes(src, dst);
return need_size;
}
void AdaptivePoolingForwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), {
auto opr = inplace_cpu_handle()->create_operator<PoolingForward>();
auto opr = inplace_cpu_handle(2)->create_operator<PoolingForward>();
opr->param() = deduce_pooling_param(src.layout, dst.layout);
opr->exec(src, dst, workspace);
});
......@@ -20,7 +26,7 @@ void AdaptivePoolingBackwardImpl::exec(
_megdnn_tensor_in src, _megdnn_tensor_in dst, _megdnn_tensor_in diff,
_megdnn_tensor_out grad, _megdnn_workspace workspace) {
MEGDNN_DISPATCH_CPU_KERN(static_cast<naive::HandleImpl*>(handle()), {
auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>();
auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>();
opr->param() = deduce_pooling_param(src.layout, dst.layout);
opr->exec(src, dst, diff, grad, workspace);
});
......@@ -29,7 +35,7 @@ void AdaptivePoolingBackwardImpl::exec(
size_t AdaptivePoolingBackwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& dst, const TensorLayout& diff,
const TensorLayout& grad) {
auto opr = inplace_cpu_handle()->create_operator<PoolingBackward>();
auto opr = inplace_cpu_handle(2)->create_operator<PoolingBackward>();
opr->param() = deduce_pooling_param(src, dst);
return opr->get_workspace_in_bytes(src, dst, diff, grad);
}
......
......@@ -11,9 +11,7 @@ public:
void exec(
_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override {
return 0;
}
size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&) override;
};
class AdaptivePoolingBackwardImpl : public AdaptivePoolingBackward {
......
#include "test/arm_common/fixture.h"
#include "megdnn/tensor_iter.h"
#include "src/common/utils.h"
#include "test/common/adaptive_pooling.h"
#include "test/common/benchmarker.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD_NCHW44) {
auto args = adaptive_pooling::get_args_nchw44();
Checker<AdaptivePooling> checker(handle());
checker.set_epsilon(1e-4);
for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)})
for (auto&& arg : args) {
auto param = arg.param;
auto src = arg.ishape;
auto dst = arg.oshape;
checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
TensorShapeArray{src, dst, {}});
}
}
TEST_F(ARM_COMMON, ADAPTIVE_POOLING_FORWARD) {
auto args = adaptive_pooling::get_args();
Checker<AdaptivePooling> checker(handle());
checker.set_epsilon(1e-4);
for (DType dtype : {(DType)dtype::Float32(), (DType)dtype::QuantizedS8(1.0)})
for (auto&& arg : args) {
auto param = arg.param;
auto src = arg.ishape;
auto dst = arg.oshape;
checker.set_param(param).set_dtype(0, dtype).set_dtype(1, dtype).exec(
TensorShapeArray{src, dst, {}});
}
}
#if MEGDNN_WITH_BENCHMARK
namespace {
void benchmark_globalpooling_nchw44_fp32(Handle* handle) {
using Param = param::AdaptivePooling;
auto run = [&](size_t n, size_t c, size_t h, size_t w, Param::Mode mode) {
Param param;
param.format = Param::Format::NCHW;
param.mode = mode;
TensorShape nchw_shape = {n, c, h, w};
TensorShape nchw_dst_shape = {n, c, 1, 1};
TensorShape nchw44_shape = {n, c / 4, h, w, 4};
TensorShape nchw44_dst_shape = {n, c / 4, 1, 1, 4};
TensorLayout dst_layout;
float calc_amount = n * c * h * w;
Benchmarker<AdaptivePooling> benchmarker_float_nchw(handle);
Benchmarker<AdaptivePooling> benchmarker_float_nchw44(handle);
Benchmarker<AdaptivePooling> benchmarker_int_nchw44(handle);
size_t RUN = 500;
auto t1 = benchmarker_float_nchw.set_display(false)
.set_times(RUN)
.set_param(param)
.exec({nchw_shape, nchw_dst_shape});
param.format = Param::Format::NCHW44;
auto t2 = benchmarker_int_nchw44.set_display(false)
.set_times(RUN)
.set_param(param)
.execl({{nchw44_shape, dtype::QuantizedS8(1.0)},
{nchw44_dst_shape, dtype::QuantizedS8(1.0)}});
auto t3 = benchmarker_float_nchw44.set_display(false)
.set_times(RUN)
.set_param(param)
.exec({nchw44_shape, nchw44_dst_shape});
printf("{%zu %zu %zu %zu} \n"
"nchw_fp32={%.3f ms, %.3f Mflops}, "
"nchw44_int={%.3f ms, %.3f Mflops}, "
"nchw44_fp32={%.3f ms, %.3f Mflops, speed_up %f}\n\n",
n, c, h, w, t1 / RUN, calc_amount / (t1 / RUN * 1000), t2 / RUN,
calc_amount / (t2 / RUN * 1000), t3 / RUN,
calc_amount / (t3 / RUN * 1000), t1 / t3);
};
run(1, 128, 25, 25, param::AdaptivePooling::Mode::AVERAGE);
}
} // namespace
TEST_F(ARM_COMMON, BENCHMARK_GLOBAL_POOLING_NCHW44_FP32) {
benchmark_globalpooling_nchw44_fp32(handle());
}
#endif
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -40,6 +40,36 @@ inline std::vector<TestArg> get_args() {
return args;
}
inline std::vector<TestArg> get_args_nchw44() {
std::vector<TestArg> args;
using Param = param::AdaptivePooling;
using Mode = param::AdaptivePooling::Mode;
for (size_t i = 36; i < 40; ++i) {
args.emplace_back(
Param{Mode::AVERAGE, Param::Format::NCHW44},
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 4, i - 2, 4});
args.emplace_back(
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4},
TensorShape{2, 3, i - 4, i - 2, 4});
args.emplace_back(
Param{Mode::AVERAGE, Param::Format::NCHW44},
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, 1, 1, 4});
args.emplace_back(
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4},
TensorShape{2, 3, 1, 1, 4});
}
for (size_t i = 5; i < 10; ++i) {
args.emplace_back(
Param{Mode::AVERAGE, Param::Format::NCHW44},
TensorShape{2, 3, i, i + 1, 4}, TensorShape{2, 3, i - 3, i - 2, 4});
args.emplace_back(
Param{Mode::MAX, Param::Format::NCHW44}, TensorShape{2, 3, i, i + 1, 4},
TensorShape{2, 3, i - 3, i - 2, 4});
}
return args;
}
} // namespace adaptive_pooling
} // namespace test
} // namespace megdnn
......
......@@ -254,7 +254,9 @@ def optimize_for_inference(dest_vars, **kwargs):
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
inference
* enable_fuse_grain: fuse_grain is enabled by default; it fuses fine-grained operators into larger fused operators and can be disabled.
)
"""
inference_options = GraphOptimizeOptions()
inference_optimize_layout_transform_map = {
......@@ -282,6 +284,8 @@ def optimize_for_inference(dest_vars, **kwargs):
inference_options.fuse_conv_bias_with_z = True
if kwargs.pop("enable_fuse_preprocess", False):
inference_options.fuse_preprocess = True
if kwargs.pop("enable_fuse_grain", True):
inference_options.fuse_grain = True
if kwargs:
raise ValueError("unknown options: %s" % list(kwargs))
......@@ -330,6 +334,8 @@ def deserialize_infer_option(x: int) -> Dict[str, bool]:
ret["enable_fuse_conv_bias_with_z"] = True
if inference_options.fuse_preprocess:
ret["enable_fuse_preprocess"] = True
if inference_options.fuse_grain:
ret["enable_fuse_grain"] = True
return ret
......
......@@ -151,7 +151,9 @@ class Network:
* enable_fuse_conv_bias_with_z: whether to fuse conv_bias with z
input for inference on nvidia backend(this optimization pass will
result in mismatch of the precision of output of training and
inference)
inference
* enable_fuse_grain: fuse_grain is enabled by default; it fuses fine-grained operators into larger fused operators and can be disabled.
)
"""
if not isinstance(dest_vars, Sequence):
......@@ -221,7 +223,6 @@ class Network:
logger.warning(
'"output_names" is not supported in Network.dump, rename output vars directly'
)
if optimize_for_inference:
out, optimize_options = G.optimize_for_inference(out, **kwargs)
......
......@@ -292,7 +292,9 @@ void init_graph_rt(py::module m) {
&_OptimizeForInferenceOptions::fuse_preprocess)
.def_readwrite(
"layout_transform",
&_OptimizeForInferenceOptions::layout_transform);
&_OptimizeForInferenceOptions::layout_transform)
.def_readwrite(
"fuse_grain", &_OptimizeForInferenceOptions::fuse_grain);
py::enum_<_LayoutTransform>(GraphOptimizeOptions, "LayoutTransform")
.value("DEFAULT", _LayoutTransform::DEFAULT)
......
......@@ -47,6 +47,7 @@ def test_metadata():
"user_info": {"str": "x", "tensor": x, "module": M.Module, "none": None},
"graph_modified": True, # True: Network.dump
"optimized_for_inference": True,
"enable_fuse_grain": True,
"enable_nchw4": True,
"enable_ioc16": True,
}
......
......@@ -76,6 +76,13 @@ void ModelMdl::make_output_spec() {
}
m_asyc_exec = m_load_result.graph_compile(m_output_spec);
auto new_output_vars = m_asyc_exec->get_output_vars();
// refresh output_var_list with the output vars of the compiled graph
mgb::cg::SymbolVarArray symbol_var_array;
symbol_var_array.reserve(new_output_vars.size());
for (auto output_var : new_output_vars) {
symbol_var_array.emplace_back(output_var);
}
m_load_result.output_var_list = symbol_var_array;
}
std::shared_ptr<mgb::serialization::GraphLoader>& ModelMdl::reset_loader(
......
......@@ -56,6 +56,43 @@ void PackModelOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
CONFIG_MODEL_FUN;
}
///////////////////// RawModelOption //////////////////////////
std::shared_ptr<OptionBase> RawModelOption::create_option() {
static std::shared_ptr<RawModelOption> option(new RawModelOption);
if (RawModelOption::is_valid()) {
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
}
RawModelOption::RawModelOption() {
m_option_name = "raw_model";
if (!FLAGS_model_dump.empty())
model_dump = FLAGS_model_dump;
}
bool RawModelOption::is_valid() {
return !FLAGS_model_dump.empty();
}
void RawModelOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
CONFIG_MODEL_FUN;
}
template <typename ModelImpl>
void RawModelOption::config_model_internel(
RuntimeParam& runtime_param, std::shared_ptr<ModelImpl> model) {
if (runtime_param.stage == RunStage::AFTER_MODEL_RUNNING) {
auto model_data = model->get_model_data();
std::ofstream ofs(model_dump, std::ios::binary);
if (!ofs.is_open()) {
mgb_log_warn("can not open file %s to write model\n", model_dump.c_str());
return;
}
ofs.write((char*)model_data.data(), model_data.size());
ofs.close();
mgb_log_warn("success write model to %s\n", model_dump.c_str());
}
}
////////////////////// PackModel gflags ////////////////////////
......@@ -79,4 +116,8 @@ DEFINE_string(
"https://megengine.megvii-inc.com/user-guide/deployment/lite/advance/"
"pack-lite-model.html for more details.");
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option);
\ No newline at end of file
///////////////////// RawModel gflags ///////////////////////////
DEFINE_string(model_dump, "", "The output file path of raw model.");
REGIST_OPTION_CREATOR(pack_model, lar::PackModelOption::create_option);
REGIST_OPTION_CREATOR(dump_model, lar::RawModelOption::create_option);
\ No newline at end of file
......@@ -3,7 +3,7 @@
#include "megbrain/graph/operator_node.h"
#include "models/model.h"
#include "option_base.h"
DECLARE_string(model_dump);
DECLARE_string(packed_model_dump);
DECLARE_string(pack_info_json);
DECLARE_string(pack_cache);
......@@ -36,4 +36,22 @@ private:
std::string pack_model_cryption;
bool is_fast_run_cache = true;
};
class RawModelOption : public OptionBase {
public:
static bool is_valid();
static std::shared_ptr<OptionBase> create_option();
void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; }
private:
RawModelOption();
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>);
std::string m_option_name;
std::string model_dump;
};
} // namespace lar
......@@ -124,6 +124,60 @@ void WeightPreprocessOption::config_model(
CONFIG_MODEL_FUN;
}
///////////////////////// fuse grain optimize options ///////////////
bool FuseGrainOption::m_valid;
namespace lar {
template <>
void FuseGrainOption::config_model_internel<ModelLite>(
RuntimeParam& runtime_param, std::shared_ptr<ModelLite>) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
if (m_fuse_grain) {
LITE_THROW("enable fuse-grain optimization not support in lite model");
}
}
}
template <>
void FuseGrainOption::config_model_internel<ModelMdl>(
RuntimeParam& runtime_param, std::shared_ptr<ModelMdl> model) {
if (runtime_param.stage == RunStage::BEFORE_MODEL_LOAD) {
auto&& graph_option = model->get_mdl_config().comp_graph->options();
if (m_fuse_grain) {
mgb_log_warn("enable fuse-grain optimization");
graph_option.graph_opt.enable_fuse_grain();
}
}
}
} // namespace lar
FuseGrainOption::FuseGrainOption() {
m_option_name = "fuse_grain";
m_fuse_grain = FLAGS_fuse_grain;
m_option = {{"fuse_grain", lar::Bool::make(false)}};
std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"])
->set_value(FLAGS_fuse_grain);
}
bool FuseGrainOption::is_valid() {
return true;
}
std::shared_ptr<OptionBase> FuseGrainOption::create_option() {
static std::shared_ptr<FuseGrainOption> option(new FuseGrainOption);
if (FuseGrainOption::is_valid()) {
return std::static_pointer_cast<OptionBase>(option);
} else {
return nullptr;
}
}
void FuseGrainOption::config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) {
m_fuse_grain =
std::static_pointer_cast<lar::Bool>(m_option["fuse_grain"])->get_value();
CONFIG_MODEL_FUN;
}
///// fuse conv bias and nonlinear activation opr optimize options ////////
bool FuseConvBiasNonlinearOption::m_valid;
namespace lar {
......@@ -579,6 +633,7 @@ void TensorRTOption::config_model(
DEFINE_bool(
enable_fuse_preprocess, false,
"Fusion astype | pad_channel | dimshuffle and etc opr from h2d opr");
DEFINE_bool(fuse_grain, false, "Enable fusing fine-grained oprs into larger fused oprs");
DEFINE_bool(
weight_preprocess, false,
"Execute operators with weight preprocess, which can optimize the "
......@@ -589,7 +644,7 @@ DEFINE_bool(
"whether to fuse conv+bias+nonlinearity");
DEFINE_bool(
enable_fuse_conv_bias_with_z, false,
"fuse conv,bias (elemwise add),z(elemwise add) into one opr "
"fuse conv, bias (elemwise add), z(elemwise add) into one opr "
"(only support on GPU)");
///////////////////////// graph retrict options /////////////////////////
......@@ -636,6 +691,9 @@ REGIST_OPTION_VALIDATER(fuse_preprocess, lar::FusePreprocessOption::set_valid);
REGIST_OPTION_CREATOR(weight_preprocess, lar::WeightPreprocessOption::create_option);
REGIST_OPTION_VALIDATER(weight_preprocess, lar::WeightPreprocessOption::set_valid);
REGIST_OPTION_CREATOR(disable_fuse_grain, lar::FuseGrainOption::create_option);
REGIST_OPTION_VALIDATER(disable_fuse_grain, lar::FuseGrainOption::set_valid);
REGIST_OPTION_CREATOR(
fuse_conv_bias_nonlinearity, lar::FuseConvBiasNonlinearOption::create_option);
REGIST_OPTION_VALIDATER(
......
......@@ -5,6 +5,7 @@
#include "option_base.h"
DECLARE_bool(enable_fuse_preprocess);
DECLARE_bool(fuse_grain);
DECLARE_bool(weight_preprocess);
DECLARE_bool(enable_fuse_conv_bias_nonlinearity);
DECLARE_bool(enable_fuse_conv_bias_with_z);
......@@ -79,7 +80,31 @@ private:
static bool m_valid;
OptionValMap m_option;
};
///////////////////////// fuse grain options //////////////
class FuseGrainOption final : public OptionBase {
public:
static bool is_valid();
static std::shared_ptr<OptionBase> create_option();
void config_model(
RuntimeParam& runtime_param, std::shared_ptr<ModelBase> model) override;
std::string option_name() const override { return m_option_name; };
static void set_valid(bool val) { m_valid = val; };
OptionValMap* get_option() override { return &m_option; }
private:
FuseGrainOption();
template <typename ModelImpl>
void config_model_internel(RuntimeParam&, std::shared_ptr<ModelImpl>){};
std::string m_option_name;
bool m_fuse_grain;
static bool m_valid;
OptionValMap m_option;
};
/////////////// fuse_conv_bias_nonlinearity optimize options ///////////////
class FuseConvBiasNonlinearOption final : public OptionBase {
public:
......
......@@ -91,7 +91,7 @@ public:
};
class OutputVarsUserData final : public mgb::UserDataContainer::UserData {
MGB_TYPEINFO_OBJ_DECL;
MGB_TYPEINFO_OBJ_DECL_WITH_EXPORT;
private:
VarNodeArray m_output_vars;
......
......@@ -91,6 +91,9 @@ struct GraphCommonOptimizeOptions {
bool weight_preprocess = false;
//! fuse preprocess patten, like astype + pad_channel + dimshuffle
bool fuse_preprocess = false;
//! fuse_grain pattern: replace fine-grained IR with coarse-grained fused IR
bool fuse_grain = false;
enum LayoutTransform : uint32_t {
DEFAULT,
NCHW4, ///< compute using NCHW4 tensor format
......@@ -124,6 +127,7 @@ struct GraphCommonOptimizeOptions {
SET(fuse_conv_bias_with_z);
SET(fuse_preprocess);
SET(weight_preprocess);
SET(fuse_grain);
#undef SET
#define SET(_trans, _trans_capital) \
GraphCommonOptimizeOptions& enable_##_trans() { \
......
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../../core/impl/graph/cg_impl.h"
#include "./gopt_helper.h"
#include "megbrain/utils/hash_ct.h"
#include "midout.h"
MIDOUT_DECL(megbrain_folding_global_pooling)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_global_pooling, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();
using namespace mgb;
using namespace gopt;
/* ==================== FoldingGlobalPoolingPass ================= */
const char* FoldingGlobalPoolingPass::name() const {
return mgb_cstr_log("folding reduce mean pass");
}
void FoldingGlobalPoolingPass::apply(OptState& opt) const {
MIDOUT_B("FoldingGlobalPoolingPass::apply");
FindNext find_tool(opt);
auto rewriter = opt.graph().make_rewriter();
/**
*
* reshape+------>reduce(mean or max)+--->axis_add_remove*n
* ||
* ||
* ||
* \/
* adaptive_pooling(1,1)
*/
auto try_fuse_global_pooling_axis_add = [&rewriter,
&find_tool](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
MGB_MARK_USED_VAR(rewriter);
MGB_MARK_USED_VAR(find_tool);
auto axis_modi = try_cast_as_op<opr::AxisAddRemove>(opr);
CHECK_OR_RETURN(axis_modi);
CHECK_OR_RETURN(find_tool.used_count(axis_modi) <= 1);
auto output_shape = axis_modi->output(0)->shape();
CHECK_OR_RETURN(output_shape.ndim == 4);
CHECK_OR_RETURN(output_shape[2] == output_shape[3] && output_shape[2] == 1);
auto axis_input = axis_modi->input(0)->owner_opr();
auto axis_modi_x = axis_input->try_cast_final<opr::AxisAddRemove>();
auto reduce = axis_input->try_cast_final<opr::Reduce>();
// walk through a chain of AxisAddRemove oprs until the Reduce is reached
while (axis_modi_x) {
CHECK_OR_RETURN(find_tool.used_count(axis_modi_x) == 1);
auto axis_input_x = axis_modi_x->input(0)->owner_opr();
reduce = axis_input_x->try_cast_final<opr::Reduce>();
axis_modi_x = axis_input_x->try_cast_final<opr::AxisAddRemove>();
}
CHECK_OR_RETURN(reduce);
auto reduce_mode = reduce->param().mode;
CHECK_OR_RETURN(
reduce_mode == opr::Reduce::Param::Mode::MAX ||
reduce_mode == opr::Reduce::Param::Mode::MEAN);
auto reduce_axis = reduce->param().axis;
CHECK_OR_RETURN(reduce_axis == 2);
auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>();
CHECK_OR_RETURN(reshape);
auto reshape_in_shape = reshape->input(0)->shape();
auto reshape_out_shape = reshape->output(0)->shape();
bool merge_hw =
reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 &&
reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2];
CHECK_OR_RETURN(merge_hw);
opr::AdaptivePooling::Param param;
if (reduce_mode == opr::Reduce::Param::Mode::MAX) {
param.mode = opr::AdaptivePooling::Param::Mode::MAX;
} else {
mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN);
param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE;
}
auto new_node = opr::AdaptivePooling::make(
rewriter.get_var(reshape->input(0)), {1, 1}, param);
rewriter.replace_var(
axis_modi->output(0), new_node.node(),
mgb_cstr_log("replace reshape+reduce+add_axis -> adaptive pooling"));
return true;
};
/**
*
* reshape+------>reduce(mean or max)+--->dimshuffle(0,1,-1,-1)
* ||
* ||
* ||
* \/
* adaptive_pooling(1,1)
*/
auto try_fuse_global_pooling_dimshuffle = [&rewriter,
&find_tool](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
MGB_MARK_USED_VAR(rewriter);
MGB_MARK_USED_VAR(find_tool);
auto dimshuffle = try_cast_as_op<opr::Dimshuffle>(opr);
CHECK_OR_RETURN(dimshuffle);
auto pattern_param = dimshuffle->param();
CHECK_OR_RETURN(pattern_param.pattern_len == 4);
auto pattern = pattern_param.pattern;
CHECK_OR_RETURN(
pattern[0] == 0 && pattern[1] == 1 && pattern[2] == -1 && pattern[3] == -1);
auto axis_remove =
dimshuffle->input(0)->owner_opr()->try_cast_final<opr::AxisAddRemove>();
CHECK_OR_RETURN(axis_remove);
auto reduce = axis_remove->input(0)->owner_opr()->try_cast_final<opr::Reduce>();
CHECK_OR_RETURN(reduce);
auto reduce_mode = reduce->param().mode;
CHECK_OR_RETURN(
reduce_mode == opr::Reduce::Param::Mode::MAX ||
reduce_mode == opr::Reduce::Param::Mode::MEAN);
auto reduce_axis = reduce->param().axis;
CHECK_OR_RETURN(reduce_axis == 2);
auto reshape = reduce->input(0)->owner_opr()->try_cast_final<opr::Reshape>();
CHECK_OR_RETURN(reshape);
auto reshape_in_shape = reshape->input(0)->shape();
auto reshape_out_shape = reshape->output(0)->shape();
bool merge_hw =
reshape_out_shape.ndim == 3 && reshape_in_shape.ndim == 4 &&
reshape_in_shape[2] * reshape_in_shape[3] == reshape_out_shape[2];
CHECK_OR_RETURN(merge_hw);
opr::AdaptivePooling::Param param;
if (reduce_mode == opr::Reduce::Param::Mode::MAX) {
param.mode = opr::AdaptivePooling::Param::Mode::MAX;
} else {
mgb_assert(reduce_mode == opr::Reduce::Param::Mode::MEAN);
param.mode = opr::AdaptivePooling::Param::Mode::AVERAGE;
}
auto new_node = opr::AdaptivePooling::make(
rewriter.get_var(reshape->input(0)), {1, 1}, param);
rewriter.replace_var(
dimshuffle->output(0), new_node.node(),
mgb_cstr_log("replace reshape+reduce+dimshuffle -> adaptive pooling"));
return true;
};
auto on_opr = [&try_fuse_global_pooling_axis_add,
&try_fuse_global_pooling_dimshuffle,
&rewriter](OperatorNodeBase* opr) {
if (!try_fuse_global_pooling_axis_add(opr) &&
!try_fuse_global_pooling_dimshuffle(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../../core/impl/graph/cg_impl.h"
#include "./gopt_helper.h"
#include "megbrain/utils/hash_ct.h"
#include "midout.h"
MIDOUT_DECL(megbrain_folding_reduce_mean)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_reduce_mean, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();
using namespace mgb;
using namespace gopt;
/* ==================== FoldingReduceMeanPass ================= */
const char* FoldingReduceMeanPass::name() const {
return mgb_cstr_log("folding reduce mean pass");
}
void FoldingReduceMeanPass::apply(OptState& opt) const {
MIDOUT_B("FoldingReduceMeanPass::apply");
FindNext find_tool(opt);
auto rewriter = opt.graph().make_rewriter();
/**
* reshape+---------->reduce(axis, sum)+--------->axis_remove+----------->true_div
* | ^
* | |
* +--------------> get_var_shape(axis)+------------>type_cvt(fp32)+-------+
* ||
* ||
* \/
* reshape+-------->reduce(axis, mean)+--------->axis_remove
*
*
**/
auto try_fuse_reduce_mean = [&rewriter, &find_tool](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
MGB_MARK_USED_VAR(rewriter);
// check true_div
auto elemwise = try_cast_as_op<opr::Elemwise>(opr);
CHECK_OR_RETURN(elemwise);
auto mode_ok = elemwise->param().mode == opr::Elemwise::Mode::TRUE_DIV;
CHECK_OR_RETURN(mode_ok);
auto input0 = elemwise->input(0)->owner_opr();
auto remove_axis = input0->try_cast_final<opr::AxisAddRemove>();
auto reduce = input0->try_cast_final<opr::Reduce>();
if (remove_axis) {
reduce = remove_axis->input(0)->owner_opr()->try_cast_final<opr::Reduce>();
}
CHECK_OR_RETURN(reduce);
bool reduce_sum = reduce->param().mode == opr::Reduce::Param::Mode::SUM;
CHECK_OR_RETURN(reduce_sum);
auto input1 = elemwise->input(1)->owner_opr();
auto typecvt = input1->try_cast_final<opr::TypeCvt>();
CHECK_OR_RETURN(typecvt);
auto is_typecvt_f32 = typecvt->param().enumv() == DTypeEnum::Float32;
CHECK_OR_RETURN(is_typecvt_f32);
auto get_var_shape =
typecvt->input(0)->owner_opr()->try_cast_final<opr::GetVarShape>();
CHECK_OR_RETURN(get_var_shape);
bool same_parent =
get_var_shape->input(0)->owner_opr() == reduce->input(0)->owner_opr();
CHECK_OR_RETURN(same_parent);
CHECK_OR_RETURN(
find_tool.used_count(get_var_shape->input(0)->owner_opr()) == 2);
bool same_axis = get_var_shape->param().axis == reduce->param().axis;
CHECK_OR_RETURN(same_axis);
auto new_reduce_param = reduce->param();
new_reduce_param.mode = opr::Reduce::Mode::MEAN;
auto new_node =
opr::Reduce::make(rewriter.get_var(reduce->input(0)), new_reduce_param);
if (remove_axis) {
new_node = opr::AxisAddRemove::make(
new_node, remove_axis->param(), remove_axis->config());
}
rewriter.replace_var(
opr->output(0), new_node.node(),
mgb_cstr_log("replace reduce_sum+div_axis -> reduce_mean"));
return true;
};
auto on_opr = [&try_fuse_reduce_mean, &rewriter](OperatorNodeBase* opr) {
if (!try_fuse_reduce_mean(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -722,7 +722,10 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
options.disable_##_option(); \
} \
}
cb(fuse_grain, {
add_pass<FoldingReduceMeanPass>();
add_pass<FoldingGlobalPoolingPass>();
});
cb(fuse_preprocess, {
add_pass(FuseNCHW4Int8Preprocess::make());
add_pass<FuseWarpPerspectiveDimshufflePass>();
......
#pragma once
#include "../../core/impl/graph/cg_impl.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/utils/hash_ct.h"
namespace mgb {
namespace gopt {
namespace {
#define CHECK_OR_RETURN(x) \
if (!(x)) { \
return false; \
}
class FindNext {
using DepType = cg::OperatorNodeProp::DepType;
public:
FindNext(OptState& opt) {
opt.graph().iter([&](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
m_readers[i.first->owner_opr()].emplace_back(opr, i.second);
}
});
}
size_t used_count(OperatorNodeBase* opr) { return m_readers[opr].size(); }
private:
ThinHashMap<OperatorNodeBase*, SmallVector<std::pair<OperatorNodeBase*, DepType>>>
m_readers;
};
} // namespace
} // namespace gopt
} // namespace mgb
\ No newline at end of file
......@@ -4,6 +4,7 @@
#include "megbrain/graph/event.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/local.h"
......@@ -1368,6 +1369,8 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
megdnn::param::Convolution::Format::NCHW88;
megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88;
megdnn::param::AdaptivePooling::Format adapt_pooling_format =
megdnn::param::AdaptivePooling::Format::NCHW88;
megdnn::param::Resize::Format resize_format = megdnn::param::Resize::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88";
......@@ -1381,6 +1384,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::Convolution::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
adapt_pooling_format = megdnn::param::AdaptivePooling::Format::NCHW44;
resize_format = megdnn::param::Resize::Format::NCHW44;
convter_pass_name = "conv_format_nchw44";
}
......@@ -1646,6 +1650,33 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
return new_opr;
}
};
auto replace_adapt_pooling_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling_opr = opr->cast_final_safe<opr::AdaptivePooling>();
mgb_throw_if(
pooling_opr.param().format !=
opr::AdaptivePoolingForward::Param::Format::NCHW,
MegBrainError,
"ConvertFormat Pass only support converting NCHW to NCHWxx");
VarNode* inp_0 = new_inp[0];
VarNode* inp_1 = new_inp[1];
//! if input is nchwxx
if (inp_0->shape().ndim == 5) {
auto new_param = pooling_opr.param();
new_param.format = adapt_pooling_format;
auto new_pooling_opr = opr::AdaptivePoolingForward::make(
inp_0, inp_1, new_param, opr->config());
mgb_assert(
new_pooling_opr.shape().ndim == 5,
"The pooling dst dim is not trans to nchwxx");
return new_pooling_opr.node()->owner_opr();
} else {
auto new_opr =
serialization::copy_opr_shallow(*opr, new_inp, opr->config());
return new_opr;
}
};
auto replace_resize_opr = [=](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -1763,6 +1794,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::AdaptivePooling::typeinfo()] = replace_adapt_pooling_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::Concat::typeinfo()] = replace_multi_inp_opr;
replace_func[opr::Elemwise::typeinfo()] = replace_multi_inp_opr;
......
......@@ -327,6 +327,8 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
ret |= 1u << 4;
if (fuse_preprocess)
ret |= 1u << 5;
if (fuse_grain)
ret |= 1u << 6;
return ret;
}
......@@ -338,6 +340,7 @@ struct OptimizeForInferenceOptions : cg::GraphCommonOptimizeOptions {
ret.fuse_conv_bias_with_z = buf & 1u << 3;
ret.weight_preprocess = buf & 1u << 4;
ret.fuse_preprocess = buf & 1u << 5;
ret.fuse_grain = buf & 1u << 6;
ret.layout_transform = (LayoutTransform)(buf >> 32);
return ret;
}
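To make the serialization change above concrete, here is a hedged, self-contained sketch of the round trip for the new fuse_grain bit (bit 6); the Opts struct below is a simplified stand-in, not the real OptimizeForInferenceOptions.
#include <cassert>
#include <cstdint>
struct Opts {
    bool fuse_preprocess = false;
    bool fuse_grain = false;
    uint64_t serialize() const {
        uint64_t ret = 0;
        if (fuse_preprocess)
            ret |= 1u << 5;
        if (fuse_grain)
            ret |= 1u << 6;  // the bit added by this commit
        return ret;
    }
    static Opts deserialize(uint64_t buf) {
        Opts ret;
        ret.fuse_preprocess = buf & (1u << 5);
        ret.fuse_grain = buf & (1u << 6);
        return ret;
    }
};
int main() {
    Opts o;
    o.fuse_grain = true;
    assert(Opts::deserialize(o.serialize()).fuse_grain);
    return 0;
}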
......@@ -477,6 +480,25 @@ public:
void apply(OptState& opt) const override;
};
#endif
/**
* \brief older MegBrain expressed reduce_mean as reduce_sum followed by a div;
* fuse them back into a single reduce_mean for efficiency
*/
class FoldingReduceMeanPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};
/**
* \brief fold an HW reduce into a global pooling, enabling the nchwxx layout
* optimization to handle it
*/
class FoldingGlobalPoolingPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};
/*!
* \brief padding channel to enable fast int8/int4 support
......
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_gen.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "./helper.h"
#include "megbrain/comp_node_env.h"
#include "megdnn/tensor_format.h"
#include <random>
#include <vector>
using namespace mgb;
namespace {
//! find the first operator of a specific type; raise an exception if not found
template <typename T>
T& find_opr(SymbolVar endpoint) {
T* found = nullptr;
auto cb = [&found](cg::OperatorNodeBase* opr) {
if (!found && opr->same_type<T>()) {
found = &opr->cast_final_safe<T>();
}
};
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str());
return *found;
}
template <typename T>
T& find_opr(SymbolVar endpoint, const std::string& node_name) {
T* found = nullptr;
auto cb = [&found, &node_name](cg::OperatorNodeBase* opr) {
if (!found && opr->same_type<T>() && opr->name() == node_name) {
found = &opr->cast_final_safe<T>();
}
};
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
mgb_assert(
found, "not found opr %s from %s", node_name.c_str(),
endpoint.node()->name().c_str());
return *found;
}
template <typename T>
size_t find_opr_num(SymbolVar endpoint) {
size_t opr_num = 0;
auto cb = [&opr_num](cg::OperatorNodeBase* opr) {
if (opr->same_type<T>()) {
opr_num++;
}
};
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
return opr_num;
}
} // namespace
TEST(TestGoptOldModel, FoldingGlobalPooling) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);
auto reshape1 = opr::Reshape::make(conv1, reshape_shape);
opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto reduce_remove_axis = opr::AxisAddRemove::make(
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)});
auto hw_count = opr::GetVarShape::make(reshape1, 2);
auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32());
auto true_div = reduce_remove_axis / fp32_hw_count;
auto y = opr::AxisAddRemove::make(
true_div, {opr::AxisAddRemove::AxisDesc::make_add(2),
opr::AxisAddRemove::AxisDesc::make_add(3)});
SymbolVar y_opt = y;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.fuse_grain = true;
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(
opr::AdaptivePooling::Param::Mode::AVERAGE,
find_opr<opr::AdaptivePooling>(y_opt).param().mode);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptOldModel, FoldingGlobalPooling2) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);
auto reshape1 = opr::Reshape::make(conv1, reshape_shape);
opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto reduce_remove_axis = opr::AxisAddRemove::make(
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)});
auto hw_count = opr::GetVarShape::make(reshape1, 2);
auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32());
auto true_div = reduce_remove_axis / fp32_hw_count;
auto y = opr::Dimshuffle::make(true_div, {0, 1, -1, -1});
SymbolVar y_opt = y;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.fuse_grain = true;
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(
opr::AdaptivePooling::Param::Mode::AVERAGE,
find_opr<opr::AdaptivePooling>(y_opt).param().mode);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptOldModel.FoldingGlobalPooling2.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptOldModel, FoldingReduceMean) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x);
opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 3, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);
auto reshape1 = opr::Reshape::make(conv1, reshape_shape);
opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto hw_count = opr::GetVarShape::make(reshape1, 2);
auto y = reduce / hw_count;
SymbolVar y_opt = y;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.fuse_grain = true;
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(
opr::Reduce::Param::Mode::MEAN, find_opr<opr::Reduce>(y_opt).param().mode);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file("TestGoptOldModel.FoldingReduceMean.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
*host_x = *gen({2, 3, 16, 16}, cn);
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
......@@ -7,6 +7,7 @@
#include "megbrain/opr/basic_arith_wrapper.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
......@@ -4113,6 +4114,71 @@ TEST(TestGoptInference, ConvertFormatNCHW44Reshape) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44GlobalPooling) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkcvar = [&](const char* name, const TensorShape& shp) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name);
};
auto host_x1 = gen({1, 4, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x1);
opr::Convolution::Param param_conv;
param_conv.stride_h = param_conv.stride_w = 1;
param_conv.pad_h = param_conv.pad_w = 1;
auto w1 = mkcvar("w1", {8, 4, 3, 3});
auto conv1 =
opr::Convolution::make(x, w1, param_conv, {}, OperatorNodeConfig("conv1"));
auto conv_n = opr::GetVarShape::make(conv1, 0);
auto conv_c = opr::GetVarShape::make(conv1, 1);
auto conv_h = opr::GetVarShape::make(conv1, 2);
auto conv_w = opr::GetVarShape::make(conv1, 3);
auto hxw = conv_h * conv_w;
auto reshape_shape = opr::Concat::make({conv_n, conv_c, hxw}, 0);
auto reshape1 = opr::Reshape::make(conv1, reshape_shape);
opr::Reduce::Param param_reduce;
param_reduce.axis = 2;
param_reduce.mode = opr::Reduce::Mode::SUM;
auto reduce = opr::Reduce::make(reshape1, param_reduce);
auto reduce_remove_axis = opr::AxisAddRemove::make(
reduce, {opr::AxisAddRemove::AxisDesc::make_remove(2)});
auto hw_count = opr::GetVarShape::make(reshape1, 2);
auto fp32_hw_count = opr::TypeCvt::make(hw_count, dtype::Float32());
auto reduce_mean = reduce_remove_axis / fp32_hw_count;
auto global_pool = opr::AxisAddRemove::make(
reduce_mean, {opr::AxisAddRemove::AxisDesc::make_add(2),
opr::AxisAddRemove::AxisDesc::make_add(3)});
opr::Elemwise::Param elem_param;
elem_param.mode = opr::Elemwise::Param::Mode::RELU;
auto y = opr::Elemwise::make({global_pool}, elem_param);
SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_grain();
options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
ASSERT_EQ(
opr::AdaptivePooling::Param::Format::NCHW44,
find_opr<opr::AdaptivePooling>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNCHW44GlobalPooling.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile(
{make_callback_copy(y, host_y), make_callback_copy(y_opt, host_y_opt)});
func->execute();
//! may go through winograd on x86-32, so set the error tolerance to 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
}
TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
......
......@@ -39,7 +39,7 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
cg::copy_tensor_value_to_shape(oshp2d, *shpinfo.shpval_inp_val.at(0));
auto src = shpinfo.shape_inp_shp.at(0);
mgb_assert(
src.ndim == 4 && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
(src.ndim == 4 || src.ndim == 5) && (oshp2d.ndim == 2 || oshp2d.ndim == 1),
"shape mismatch for AdaptivePooling: src=%s, out2d=%s",
src.to_string().c_str(), oshp2d.to_string().c_str());
......@@ -57,8 +57,19 @@ void AdaptivePoolingForward::outshape_by_symvar_do_get_output_shape(
dest.shape[1] = oshp2d.shape[0];
dest.shape[2] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
dest.shape[3] = src.shape[3];
} else if (
param_format == Param::Format::NCHW44 ||
param_format == Param::Format::NCHW88) {
dest.ndim = 5;
dest.shape[0] = src.shape[0];
dest.shape[1] = src.shape[1];
dest.shape[2] = oshp2d.shape[0];
dest.shape[3] = (tshp1n) ? oshp2d.shape[0] : oshp2d.shape[1];
dest.shape[4] = src.shape[4];
} else {
mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
mgb_throw(
MegBrainError, "AdaptivePooling not support %d format",
(int)param_format);
}
}
......
......@@ -48,7 +48,6 @@ void run(Param::Mode mode) {
Checker::RunOptions opt;
opt.numdiff_max_err = 1e-2;
Checker checker{make_graph, fwd};
checker.set_input_allow_grad(1, false).set_input_generator(0, gen);
checker.run({TensorShape{1, 1, 10, 7}, TensorShape{5, 4}}, opt);
......