提交 a6230ba9 编写于 作者: M Megvii Engine Team

feat(mgb/gopt): global layout transform support arm

GitOrigin-RevId: db50b33c112b99ab6f34cd81d9cf62790fc87c6e
上级 0be6ca88
...@@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options( ...@@ -820,23 +820,26 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
_passes need_param_fuse = true; \ _passes need_param_fuse = true; \
} }
using Target = GraphTuningOptions::Target;
cb(layout_transform, { cb(layout_transform, {
add_pass<FuseConvBiasNonlinPass>(); add_pass<FuseConvBiasNonlinPass>();
if (options.target == Target::CUDA)
add_pass<FuseConvBiasZPass>(); add_pass<FuseConvBiasZPass>();
add_pass(LayoutTransformPass::make(options.target)); add_pass(LayoutTransformPass::make(options.target));
add_pass<ShuffleShuffleRemovePass>(); add_pass<ShuffleShuffleRemovePass>();
if (options.target == Target::CUDA) {
add_pass(FuseNCHW4Int8Preprocess::make()); add_pass(FuseNCHW4Int8Preprocess::make());
add_pass<FuseWarpPerspectiveDimshufflePass>(); add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
add_pass<FoldingConvBiasDimshufflePass>(); add_pass<FoldingConvBiasDimshufflePass>();
add_pass<FoldingConvBiasTypecvtPass>(); add_pass<FoldingConvBiasTypecvtPass>();
#endif #endif
}
}); });
#undef cb #undef cb
if (need_param_fuse) { if (need_param_fuse) {
add_pass<ParamFusePass>(); add_pass<ParamFusePass>();
add_pass<ParamMergePass>();
} }
return *this; return *this;
} }
......
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#include "megbrain/opr/dnn/pooling.h" #include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h" #include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h" #include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
using namespace mgb; using namespace mgb;
using namespace gopt; using namespace gopt;
...@@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx( ...@@ -82,6 +83,44 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
{OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64}); {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
return ctx; return ctx;
} }
std::unique_ptr<LayoutTransformContext> make_arm_ctx(
OprFormat base_opr_format, TensorFormats base_tensor_format) {
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::Resize::typeinfo(),
opr::PowC::typeinfo(),
opr::Concat::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NCHWc4,
DNN_INC_FLOAT16(TensorFormats::NCHWc8)};
Attribute attribute = {base_opr_format, base_tensor_format, Target::ARM};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88), OprFormat::NCHW44_DOT})
.add_opr_config(opr::PoolingForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88)})
.add_opr_config(opr::ResizeForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW44,
DNN_INC_FLOAT16(OprFormat::NCHW88)});
return ctx;
}
} // namespace } // namespace
/* ================= LayoutTransformContext ==================*/ /* ================= LayoutTransformContext ==================*/
...@@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make( ...@@ -110,6 +149,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
switch (target) { switch (target) {
case Target::CUDA: case Target::CUDA:
return make_cuda_ctx(base_opr_format, base_tensor_format); return make_cuda_ctx(base_opr_format, base_tensor_format);
case Target::ARM:
return make_arm_ctx(base_opr_format, base_tensor_format);
default: default:
mgb_assert(false, "unsupported target %s\n", target_to_string(target)); mgb_assert(false, "unsupported target %s\n", target_to_string(target));
} }
......
...@@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const { ...@@ -60,6 +60,7 @@ void LayoutTransformPass::apply(OptState& opt) const {
auto&& opr_configs = m_ctx->opr_configs(); auto&& opr_configs = m_ctx->opr_configs();
auto&& base_fmt = m_ctx->attribute().base_tensor_formats; auto&& base_fmt = m_ctx->attribute().base_tensor_formats;
auto&& base_opr_fmt = m_ctx->attribute().base_opr_format;
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute; auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
ThinHashMap<VarNode*, TensorFormats> var2fmts; ThinHashMap<VarNode*, TensorFormats> var2fmts;
static ThinHashSet<Typeinfo*> format_aware_oprs = { static ThinHashSet<Typeinfo*> format_aware_oprs = {
...@@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const { ...@@ -68,15 +69,18 @@ void LayoutTransformPass::apply(OptState& opt) const {
#undef cb #undef cb
}; };
auto rewriter = opt.graph().make_rewriter(); auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&opr_configs, &base_fmt, &reformat_attribute, &rewriter, &solution, auto on_opr = [&opr_configs, &base_fmt, &base_opr_fmt, &reformat_attribute,
&var2fmts, &endpoint_vars](OperatorNodeBase* opr) { &rewriter, &solution, &var2fmts,
&endpoint_vars](OperatorNodeBase* opr) {
auto it = solution.find(opr); auto it = solution.find(opr);
if (it != solution.end()) { if (it != solution.end()) {
auto opr_fmt = it->second; auto opr_fmt = it->second;
auto find = opr_configs.find(opr->dyn_typeinfo()); auto find = opr_configs.find(opr->dyn_typeinfo());
Maybe<OprTensorFormatsConfiguration> fmtcfg = None; Maybe<OprTensorFormatsConfiguration> fmtcfg = None;
Maybe<OprTensorFormatsConfiguration> basecfg = None;
if (find != opr_configs.end()) { if (find != opr_configs.end()) {
fmtcfg = (*find->second.at(opr_fmt))(opr); fmtcfg = (*find->second.at(opr_fmt))(opr);
basecfg = (*find->second.at(base_opr_fmt))(opr);
} }
VarNodeArray new_inp; VarNodeArray new_inp;
size_t nr_inps = opr->input().size(); size_t nr_inps = opr->input().size();
...@@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const { ...@@ -103,6 +107,10 @@ void LayoutTransformPass::apply(OptState& opt) const {
bool is_parameter = bool is_parameter =
fmtcfg.valid() && fmtcfg.valid() &&
fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT; fmtcfg.val().input_tensor_types[i] == TensorType::WEIGHT;
if (is_parameter) {
mgb_assert(basecfg.valid());
from = basecfg.val().input_tensor_formats[i];
}
// need relayout // need relayout
if (from != to && !new_var->shape().is_scalar()) { if (from != to && !new_var->shape().is_scalar()) {
ReformatManager::ReformatImpl reformat; ReformatManager::ReformatImpl reformat;
......
...@@ -78,6 +78,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> { ...@@ -78,6 +78,48 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW> {
} }
}; };
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW44> {
static Maybe<OprTensorFormatsConfiguration> dispatch(
const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE};
config.output_dtypes = {opr->output(0)->dtype().enumv()};
config.input_tensor_formats = {TensorFormats::NCHWc4};
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (!available)
return None;
return config;
}
};
#if !MEGDNN_DISABLE_FLOAT16
template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW88> {
static Maybe<OprTensorFormatsConfiguration> dispatch(
const OperatorNodeBase* opr) {
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW88;
bool available = true;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE};
config.output_dtypes = {opr->output(0)->dtype().enumv()};
config.input_tensor_formats = {TensorFormats::NCHWc8};
config.output_tensor_formats = {TensorFormats::NCHWc8};
if (!available)
return None;
return config;
}
};
#endif
template <> template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> { struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormat::NCHW4> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
...@@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> { ...@@ -200,7 +242,7 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW> {
// setup tensor formats // setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
config.input_tensor_formats = { config.input_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::NCHW, TensorFormats::KCRS, TensorFormats::NCHW,
TensorFormats::NCHW}; TensorFormats::NCHW};
} else { } else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP); mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
...@@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> { ...@@ -396,6 +438,145 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::CHWN4> {
} }
}; };
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44> {
static Maybe<OprTensorFormatsConfiguration> dispatch(
const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type =
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
// setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::KCRSc4k4,
TensorFormats::NCHWc4, TensorFormats::NCHWc4};
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
if (is_channel_wise_conv<Opr>(opr)) {
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::C11RSc4,
TensorFormats::NCHWc4, TensorFormats::NCHWc4};
} else {
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::GKCRSc4k4,
TensorFormats::NCHWc4, TensorFormats::NCHWc4};
}
}
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (!available)
return None;
return config;
}
};
#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW88> {
static Maybe<OprTensorFormatsConfiguration> dispatch(
const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW88;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type =
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
// setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
config.input_tensor_formats = {
TensorFormats::NCHWc8, TensorFormats::KCRSc8k8,
TensorFormats::NCHWc8, TensorFormats::NCHWc8};
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
if (is_channel_wise_conv<Opr>(opr)) {
config.input_tensor_formats = {
TensorFormats::NCHWc8, TensorFormats::C11RSc8,
TensorFormats::NCHWc8, TensorFormats::NCHWc8};
} else {
config.input_tensor_formats = {
TensorFormats::NCHWc8, TensorFormats::GKCRSc8k8,
TensorFormats::NCHWc8, TensorFormats::NCHWc8};
}
}
config.output_tensor_formats = {TensorFormats::NCHWc8};
if (!available)
return None;
return config;
}
};
#endif
template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormat::NCHW44_DOT> {
static Maybe<OprTensorFormatsConfiguration> dispatch(
const OperatorNodeBase* opr) {
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44_DOT;
bool available = true;
// setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) {
if (i == 2) {
available &= opr->input(i)->dtype().enumv() ==
DTypeEnum::QuantizedS32;
} else {
available &= opr->input(i)->dtype().enumv() ==
DTypeEnum::QuantizedS8 ||
opr->input(i)->dtype().enumv() ==
DTypeEnum::Quantized8Asymm;
}
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type =
i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &=
opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
// setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::KCRSk4c4,
TensorFormats::NCHWc4, TensorFormats::NCHWc4};
} else {
mgb_assert(conv.param().sparse == Opr::Param::Sparse::GROUP);
if (is_channel_wise_conv<Opr>(opr)) {
available = false;
} else {
config.input_tensor_formats = {
TensorFormats::NCHWc4, TensorFormats::GKCRSk4c4,
TensorFormats::NCHWc4, TensorFormats::NCHWc4};
}
}
config.output_tensor_formats = {TensorFormats::NCHWc4};
if (!available)
return None;
return config;
}
};
template <> template <>
struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> { struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData, OprFormat::NCHW> {
using Opr = opr::ConvolutionBackwardData; using Opr = opr::ConvolutionBackwardData;
...@@ -530,9 +711,19 @@ StaticData::StaticData() { ...@@ -530,9 +711,19 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
#endif
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
#endif
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionBackwardData, NHWC);
...@@ -549,6 +740,16 @@ StaticData::StaticData() { ...@@ -549,6 +740,16 @@ StaticData::StaticData() {
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, CHWN4);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
#endif
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
#endif
#undef OPR_TENSOR_FORMATS_CONFIG_REG #undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
......
...@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) { ...@@ -35,9 +35,9 @@ OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
case TensorFormats::NCHW: case TensorFormats::NCHW:
return OprFormat::NCHW; return OprFormat::NCHW;
case TensorFormats::NCHWc4: case TensorFormats::NCHWc4:
return OprFormat::NCHW4; return OprFormat::NCHW44;
case TensorFormats::NCHWc8: case TensorFormats::NCHWc8:
return OprFormat::NCHW8; return OprFormat::NCHW88;
case TensorFormats::NCHWc32: case TensorFormats::NCHWc32:
return OprFormat::NCHW32; return OprFormat::NCHW32;
case TensorFormats::NCHWc64: case TensorFormats::NCHWc64:
...@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons ...@@ -424,11 +424,11 @@ ProfilerImpl::ProfilingResult ProfilerImpl::profile(const Problem& problem) cons
skip &= problem.graph_partition().input().count(i) > 0 || skip &= problem.graph_partition().input().count(i) > 0 ||
skip_oprs.count(i->owner_opr()) > 0; skip_oprs.count(i->owner_opr()) > 0;
} }
skip &= skip_opr_types.count(opr->dyn_typeinfo()); auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
skip &= find == format_aware_input_tensors.end();
if (skip) if (skip)
skip_oprs.insert(opr); skip_oprs.insert(opr);
oprs.insert(opr); oprs.insert(opr);
auto find = format_aware_input_tensors.find(opr->dyn_typeinfo());
if (find == format_aware_input_tensors.end()) { if (find == format_aware_input_tensors.end()) {
for (auto&& i : opr->input()) { for (auto&& i : opr->input()) {
if (!cvprop.is_const(i)) { if (!cvprop.is_const(i)) {
......
...@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( ...@@ -470,9 +470,9 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) { input_shape[i].extent() == Dimension::UNDETERMINED_EXTENT) {
in_channels = orig_var->shape()[i] * input_shape[i].stride(); in_channels = orig_var->shape()[i] * input_shape[i].stride();
input_channel_idx = i; input_channel_idx = i;
// mgb_assert(input_shape[i].stride() == 1, mgb_assert(
// "unsupport weight format(got:%s)", input_shape[i].stride() == 1, "unsupport weight format(got:%s)",
// input_shape.to_string().c_str()); input_shape.to_string().c_str());
} else if ( } else if (
(input_shape[i].name() == Dimension::Name::K || (input_shape[i].name() == Dimension::Name::K ||
input_shape[i].name() == Dimension::Name::N) && input_shape[i].name() == Dimension::Name::N) &&
...@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( ...@@ -485,13 +485,23 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
input_shape.to_string().c_str()); input_shape.to_string().c_str());
} }
} }
/* \notes: FIXME this is a hack. Since the layout of weight in channelwise
* convolution does not have output channel dimension, so we mannually modify the
* out_channel_name, out_channel_idx to bypass the following assertion statements. */
bool is_channelwise = key.input_format == TensorFormats::C11RS;
if (is_channelwise) {
out_channel_name = Dimension::Name::K;
out_channels = in_channels;
output_channel_idx = input_channel_idx;
}
mgb_assert( mgb_assert(
out_channel_name == Dimension::Name::K || out_channel_name == Dimension::Name::K ||
out_channel_name == Dimension::Name::N, out_channel_name == Dimension::Name::N,
"invalid out channel(shp:%s)", input_shape.to_string().c_str()); "invalid out channel(shp:%s)", input_shape.to_string().c_str());
mgb_assert( mgb_assert(
input_channel_idx < input_shape.ndim && (input_channel_idx < input_shape.ndim &&
output_channel_idx < input_shape.ndim, output_channel_idx < input_shape.ndim) ||
(is_channelwise && output_channel_idx == input_channel_idx),
"invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)", "invalid channel idx(in_channel:%zu, out_channel:%zu, shp:%s)",
input_channel_idx, output_channel_idx, input_shape.to_string().c_str()); input_channel_idx, output_channel_idx, input_shape.to_string().c_str());
size_t in_channel_alignment = 0, out_channel_alignment = 0; size_t in_channel_alignment = 0, out_channel_alignment = 0;
...@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight( ...@@ -506,6 +516,13 @@ ReformatManager::ReformatImpl ReformatManager::auto_aligned_reformat_weight(
out_channel_alignment = output_shape[i].stride(); out_channel_alignment = output_shape[i].stride();
} }
} }
/* \notes: FIXME this is a hack. Since the layout of weight in channelwise
* convolution does not have output channel dimension, so we mannually modify the
* out_channel_alignment to bypass the following assertion statements. */
if (is_channelwise) {
mgb_assert(out_channel_alignment == 0);
out_channel_alignment = 1;
}
mgb_assert( mgb_assert(
in_channel_alignment > 0 && out_channel_alignment > 0, in_channel_alignment > 0 && out_channel_alignment > 0,
"invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)", "invalid alignment(in_channel:%zu, out_channel:%zu, shp:%s)",
......
...@@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract( ...@@ -263,20 +263,9 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
std::vector<GraphPartition> partitions; std::vector<GraphPartition> partitions;
partitions.reserve(topo.size()); partitions.reserve(topo.size());
ThinHashMap<OperatorNodeBase*, GraphPartition*> roots; ThinHashMap<OperatorNodeBase*, GraphPartition*> roots;
/// backward pass
for (const auto& opr : reverse_adaptor(topo)) { for (const auto& opr : reverse_adaptor(topo)) {
if (m_opr_list.count(opr->dyn_typeinfo()) == 0) { if (m_opr_list.count(opr->dyn_typeinfo()) > 0) {
for (const auto& i : opr->input()) {
if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
auto root = union_find(i->owner_opr());
GraphPartition* partition;
auto find = roots.find(root);
if (find != roots.end()) {
partition = find->second;
partition->output().insert(i);
}
}
}
} else {
auto root = union_find(opr); auto root = union_find(opr);
auto find = roots.find(root); auto find = roots.find(root);
GraphPartition* partition = nullptr; GraphPartition* partition = nullptr;
...@@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract( ...@@ -304,6 +293,23 @@ std::vector<GraphPartition> SubGraphExtractor::extract(
partition->input().insert(i); partition->input().insert(i);
} }
} }
/// forward pass
for (auto&& opr : topo) {
if (m_opr_list.count(opr->dyn_typeinfo()) == 0) {
for (const auto& i : opr->input()) {
if (m_opr_list.count(i->owner_opr()->dyn_typeinfo())) {
auto root = union_find(i->owner_opr());
GraphPartition* partition;
auto find = roots.find(root);
if (find != roots.end()) {
partition = find->second;
partition->output().insert(i);
}
}
}
}
}
for (auto&& partition : partitions) { for (auto&& partition : partitions) {
auto& all_oprs = partition.all_oprs(); auto& all_oprs = partition.all_oprs();
std::reverse(all_oprs.begin(), all_oprs.end()); std::reverse(all_oprs.begin(), all_oprs.end());
......
...@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string( ...@@ -29,6 +29,9 @@ static inline const char* opr_format_to_string(
cb(NCHW32); cb(NCHW32);
cb(NCHW64); cb(NCHW64);
cb(CHWN4); cb(CHWN4);
cb(NCHW44);
cb(NCHW88);
cb(NCHW44_DOT);
default: default:
mgb_assert( mgb_assert(
false, "Invalid opr format(got:%u)", false, "Invalid opr format(got:%u)",
...@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats( ...@@ -53,6 +56,10 @@ static inline TensorFormats opr_format_to_tensor_formats(
return TensorFormats::NCHWc64; return TensorFormats::NCHWc64;
case OprFormat::CHWN4: case OprFormat::CHWN4:
return TensorFormats::CHWNc4; return TensorFormats::CHWNc4;
case OprFormat::NCHW88:
return TensorFormats::NCHWc8;
case OprFormat::NCHW44:
return TensorFormats::NCHWc4;
default: default:
mgb_throw( mgb_throw(
AssertionError, "format(%s) is not supported", AssertionError, "format(%s) is not supported",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册