提交 ca7cec7a 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): minor fixes for global layout transform

Merge Target::ARM and Target::X86 into Target::CPU to make global layout transform easier to use

GitOrigin-RevId: cc9363fa380896b792874206dbf40f3acb07028c
上级 93152dfa
...@@ -830,9 +830,9 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -830,9 +830,9 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]); src[3], cflt.dilated_spatial[1], cflt.stride[1], cflt.padding[1]);
dst[4] = 32; dst[4] = 32;
} else if (param().format == Param::Format::NCHW88) { } else if (param().format == Param::Format::NCHW88) {
megdnn_assert(src.ndim == 5 || src.ndim == 4, megdnn_assert(
"invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim == 5 || src.ndim == 4,
src.ndim); "invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
dst.ndim = 5; dst.ndim = 5;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
...@@ -850,11 +850,12 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet ...@@ -850,11 +850,12 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
"%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group); "%s icpg=%u group=%u", errmsg().c_str(), cflt.icpg, cflt.group);
} }
} else if (param().format == Param::Format::NCHW44 || } else if (
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT) { param().format == Param::Format::NCHW44_DOT) {
megdnn_assert(src.ndim == 5 || src.ndim == 4, megdnn_assert(
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim == 5 || src.ndim == 4,
src.ndim); "invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
dst.ndim = 5; dst.ndim = 5;
dst[0] = src[0]; dst[0] = src[0];
auto oc = cflt.ocpg * cflt.group; auto oc = cflt.ocpg * cflt.group;
......
...@@ -491,7 +491,6 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve( ...@@ -491,7 +491,6 @@ DynamicProgrammingSolver::Solution DynamicProgrammingSolver::Impl::solve(
auto& states = cuts.back().states; auto& states = cuts.back().states;
prune(states, edges[cur], ctx); prune(states, edges[cur], ctx);
force_prune(states); force_prune(states);
} }
cur++; cur++;
} }
......
...@@ -32,8 +32,7 @@ const char* target_to_string(Target target) { ...@@ -32,8 +32,7 @@ const char* target_to_string(Target target) {
return #_target return #_target
switch (target) { switch (target) {
cb(CUDA); cb(CUDA);
cb(X86); cb(CPU);
cb(ARM);
cb(UNSPEC); cb(UNSPEC);
default: default:
mgb_assert( mgb_assert(
...@@ -89,7 +88,7 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx( ...@@ -89,7 +88,7 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
return ctx; return ctx;
} }
std::unique_ptr<LayoutTransformContext> make_arm_ctx( std::unique_ptr<LayoutTransformContext> make_cpu_ctx(
OprFormatConfigID base_config_id, TensorFormats base_tensor_format) { OprFormatConfigID base_config_id, TensorFormats base_tensor_format) {
OprList opr_list = { OprList opr_list = {
opr::ConvBiasForward::typeinfo(), opr::ConvBiasForward::typeinfo(),
...@@ -104,34 +103,30 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx( ...@@ -104,34 +103,30 @@ std::unique_ptr<LayoutTransformContext> make_arm_ctx(
}; };
SmallVector<TensorFormats> available_tensor_formats = { SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NCHWc4, TensorFormats::NCHW, TensorFormats::NCHWc4, TensorFormats::NCHWc8};
DNN_INC_FLOAT16(TensorFormats::NCHWc8)}; Attribute attribute = {base_config_id, base_tensor_format, Target::CPU};
Attribute attribute = {base_config_id, base_tensor_format, Target::ARM};
auto ctx = std::make_unique<LayoutTransformContext>( auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute); std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config( ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(), opr::ConvBiasForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW44_HYBRID, OprFormatConfigID::NCHW44_HYBRID, OprFormatConfigID::NCHW88,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88), OprFormatConfigID::NCHW88_HYBRID, OprFormatConfigID::NCHW44_DOT,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID), OprFormatConfigID::NCHW44_DOT_HYBRID})
OprFormatConfigID::NCHW44_DOT, OprFormatConfigID::NCHW44_DOT_HYBRID})
.add_opr_config( .add_opr_config(
opr::ConvolutionForward::typeinfo(), opr::ConvolutionForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW44_HYBRID, OprFormatConfigID::NCHW44_HYBRID, OprFormatConfigID::NCHW88,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88), OprFormatConfigID::NCHW88_HYBRID, OprFormatConfigID::NCHW44_DOT,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88_HYBRID),
OprFormatConfigID::NCHW44_DOT,
OprFormatConfigID::NCHW44_DOT_HYBRID}) OprFormatConfigID::NCHW44_DOT_HYBRID})
.add_opr_config( .add_opr_config(
opr::PoolingForward::typeinfo(), opr::PoolingForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)}) OprFormatConfigID::NCHW88})
.add_opr_config( .add_opr_config(
opr::ResizeForward::typeinfo(), opr::ResizeForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44, {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW44,
DNN_INC_FLOAT16(OprFormatConfigID::NCHW88)}); OprFormatConfigID::NCHW88});
return ctx; return ctx;
} }
} // namespace } // namespace
...@@ -162,8 +157,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make( ...@@ -162,8 +157,8 @@ std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
switch (target) { switch (target) {
case Target::CUDA: case Target::CUDA:
return make_cuda_ctx(base_config_id, base_tensor_format); return make_cuda_ctx(base_config_id, base_tensor_format);
case Target::ARM: case Target::CPU:
return make_arm_ctx(base_config_id, base_tensor_format); return make_cpu_ctx(base_config_id, base_tensor_format);
default: default:
mgb_assert(false, "unsupported target %s\n", target_to_string(target)); mgb_assert(false, "unsupported target %s\n", target_to_string(target));
} }
......
...@@ -82,6 +82,8 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW> { ...@@ -82,6 +82,8 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW> {
} }
}; };
/* \remark: Here, maybe we needn't check data type of input and output tensors. Because
* algo available checker will skip the configuration that has no underlying impls. */
template <> template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> { struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
...@@ -89,8 +91,9 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> { ...@@ -89,8 +91,9 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW44; config.opr_format = OprFormat::NCHW44;
config.config_id = OprFormatConfigID::NCHW44; config.config_id = OprFormatConfigID::NCHW44;
bool available = true; bool f32_config = opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float32; bool i8_config = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
bool available = f32_config || i8_config;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
config.output_dtypes = {opr->output(0)->dtype().enumv()}; config.output_dtypes = {opr->output(0)->dtype().enumv()};
...@@ -102,7 +105,6 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> { ...@@ -102,7 +105,6 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW44> {
} }
}; };
#if !MEGDNN_DISABLE_FLOAT16
template <> template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> { struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
...@@ -110,8 +112,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> { ...@@ -110,8 +112,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW88; config.opr_format = OprFormat::NCHW88;
config.config_id = OprFormatConfigID::NCHW88; config.config_id = OprFormatConfigID::NCHW88;
bool available = true; bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Float32;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Float16;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
config.output_dtypes = {opr->output(0)->dtype().enumv()}; config.output_dtypes = {opr->output(0)->dtype().enumv()};
...@@ -122,7 +123,6 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> { ...@@ -122,7 +123,6 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW88> {
return config; return config;
} }
}; };
#endif
template <> template <>
struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> { struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> {
...@@ -131,8 +131,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> { ...@@ -131,8 +131,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW4> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4; config.opr_format = OprFormat::NCHW4;
config.config_id = OprFormatConfigID::NCHW4; config.config_id = OprFormatConfigID::NCHW4;
bool available = true; bool available = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
...@@ -152,8 +151,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::CHWN4> { ...@@ -152,8 +151,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::CHWN4> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::CHWN4; config.opr_format = OprFormat::CHWN4;
config.config_id = OprFormatConfigID::CHWN4; config.config_id = OprFormatConfigID::CHWN4;
bool available = true; bool available = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
...@@ -173,8 +171,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW32> { ...@@ -173,8 +171,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW32> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW32; config.opr_format = OprFormat::NCHW32;
config.config_id = OprFormatConfigID::NCHW32; config.config_id = OprFormatConfigID::NCHW32;
bool available = true; bool available = opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
available &= opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8; available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
...@@ -194,8 +191,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWC> { ...@@ -194,8 +191,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWC> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWC; config.opr_format = OprFormat::NHWC;
config.config_id = OprFormatConfigID::NHWC; config.config_id = OprFormatConfigID::NHWC;
bool available = true; bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4; opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
...@@ -216,8 +212,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> { ...@@ -216,8 +212,7 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NCHW64> {
config.typeinfo = opr->dyn_typeinfo(); config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW64; config.opr_format = OprFormat::NCHW64;
config.config_id = OprFormatConfigID::NCHW64; config.config_id = OprFormatConfigID::NCHW64;
bool available = true; bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
available &= opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4; opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
config.input_dtypes = {opr->input(0)->dtype().enumv()}; config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE}; config.input_tensor_types = {TensorType::FEATURE};
...@@ -552,14 +547,24 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44> { ...@@ -552,14 +547,24 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44> {
config.opr_format = OprFormat::NCHW44; config.opr_format = OprFormat::NCHW44;
config.config_id = OprFormatConfigID::NCHW44; config.config_id = OprFormatConfigID::NCHW44;
bool available = true; bool available = true;
auto check_dtype = [](DType dt, bool is_bias) {
bool f32_config = dt.enumv() == DTypeEnum::Float32;
auto i8_dtype = DTypeEnum::QuantizedS8;
if (is_bias)
i8_dtype = DTypeEnum::QuantizedS32;
bool i8_config = dt.enumv() == i8_dtype;
return f32_config || i8_config;
};
// setup dtypes // setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) { for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; bool is_bias =
ConvParamTrait<Opr>::has_bias && i == ConvParamTrait<Opr>::bias_idx;
available &= check_dtype(opr->input(i)->dtype(), is_bias);
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type); config.input_tensor_types.emplace_back(tensor_type);
} }
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; available &= check_dtype(opr->output(0)->dtype(), false);
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
// setup tensor formats // setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
...@@ -594,14 +599,24 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> { ...@@ -594,14 +599,24 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
config.opr_format = OprFormat::NCHW44; config.opr_format = OprFormat::NCHW44;
config.config_id = OprFormatConfigID::NCHW44_HYBRID; config.config_id = OprFormatConfigID::NCHW44_HYBRID;
bool available = true; bool available = true;
auto check_dtype = [](DType dt, bool is_bias) {
bool f32_config = dt.enumv() == DTypeEnum::Float32;
auto i8_dtype = DTypeEnum::QuantizedS8;
if (is_bias)
i8_dtype = DTypeEnum::QuantizedS32;
bool i8_config = dt.enumv() == i8_dtype;
return f32_config || i8_config;
};
// setup dtypes // setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) { for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32; bool is_bias =
ConvParamTrait<Opr>::has_bias && i == ConvParamTrait<Opr>::bias_idx;
available &= check_dtype(opr->input(i)->dtype(), is_bias);
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type); config.input_tensor_types.emplace_back(tensor_type);
} }
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32; available &= check_dtype(opr->output(0)->dtype(), false);
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE; available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = { config.input_tensor_formats = {
...@@ -614,7 +629,6 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> { ...@@ -614,7 +629,6 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
} }
}; };
#if !MEGDNN_DISABLE_FLOAT16
template <typename Opr> template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> { struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> {
static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) { static Maybe<OprTensorFormatsConfiguration> dispatch(const OperatorNodeBase* opr) {
...@@ -626,12 +640,12 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> { ...@@ -626,12 +640,12 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88> {
bool available = true; bool available = true;
// setup dtypes // setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) { for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type); config.input_tensor_types.emplace_back(tensor_type);
} }
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
// setup tensor formats // setup tensor formats
if (conv.param().sparse == Opr::Param::Sparse::DENSE) { if (conv.param().sparse == Opr::Param::Sparse::DENSE) {
...@@ -668,12 +682,12 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> { ...@@ -668,12 +682,12 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
bool available = true; bool available = true;
// setup dtypes // setup dtypes
for (size_t i = 0; i < opr->input().size(); ++i) { for (size_t i = 0; i < opr->input().size(); ++i) {
available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float16; available &= opr->input(i)->dtype().enumv() == DTypeEnum::Float32;
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv()); config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE; TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type); config.input_tensor_types.emplace_back(tensor_type);
} }
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float16; available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv()); config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE; available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
// setup tensor formats // setup tensor formats
...@@ -686,7 +700,6 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> { ...@@ -686,7 +700,6 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
return config; return config;
} }
}; };
#endif
template <typename Opr> template <typename Opr>
struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT> { struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT> {
...@@ -914,10 +927,8 @@ StaticData::StaticData() { ...@@ -914,10 +927,8 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88_HYBRID);
#endif
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);
...@@ -925,10 +936,8 @@ StaticData::StaticData() { ...@@ -925,10 +936,8 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88_HYBRID);
#endif
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID); OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44_DOT_HYBRID);
...@@ -949,15 +958,11 @@ StaticData::StaticData() { ...@@ -949,15 +958,11 @@ StaticData::StaticData() {
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW32);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW64);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(PoolingForward, NCHW88);
#endif
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW);
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW44);
#if !MEGDNN_DISABLE_FLOAT16
OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88); OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG(ResizeForward, NCHW88);
#endif
#undef OPR_TENSOR_FORMATS_CONFIG_REG #undef OPR_TENSOR_FORMATS_CONFIG_REG
#undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG #undef OPR_SINGLE_IN_OUT_TENSOR_FORMATS_CONFIG_REG
......
...@@ -357,9 +357,8 @@ struct GraphTuningOptions { ...@@ -357,9 +357,8 @@ struct GraphTuningOptions {
enum class Target : uint32_t { enum class Target : uint32_t {
UNSPEC = 0, ///< unspecific device target UNSPEC = 0, ///< unspecific device target
CUDA = 1, ///< CUDA device, usually refer to GPU devices of Nvidia CUDA = 1, ///< CUDA device, usually refer to GPU devices of Nvidia
X86 = 2, ///< x86 cpu CPU = 2, ///< cpu
ARM = 3, ///< arm cpu OPENCL = 3, ///< opencl, usually run on mobile devices
OPENCL = 4, ///< opencl, usually run on mobile devices
}; };
Target target; Target target;
bool layout_transform = false; ///< whether to enable graph level bool layout_transform = false; ///< whether to enable graph level
......
此差异由.gitattributes 抑制。
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "megbrain/plugin/profiler.h" #include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/serializer.h" #include "megbrain/serialization/serializer.h"
#define MGB_WITH_CACHED_TEST 0 #define MGB_WITH_CACHED_TEST 1
#if MGB_WITH_CACHED_TEST #if MGB_WITH_CACHED_TEST
#include "./cache_data.h" #include "./cache_data.h"
...@@ -923,9 +923,196 @@ TEST(TestLayoutTransform, MobileNetV2) { ...@@ -923,9 +923,196 @@ TEST(TestLayoutTransform, MobileNetV2) {
HostTensorND t2; HostTensorND t2;
auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)}); auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
func2->execute(); func2->execute();
gprof.to_json_full(func2.get())->writeto_fpath(output_file("mobilenet_v2_f32.json")); gprof.to_json_full(func2.get())
->writeto_fpath(output_file("mobilenet_v2_f32.json"));
/// check correct /// check correct
MGB_ASSERT_TENSOR_EQ(t1, t2); MGB_ASSERT_TENSOR_EQ(t1, t2);
} }
TEST(TestLayoutTransform, MobileNetV2_NCHW88) {
auto cn = CompNode::load("cpu0");
Network network(cn);
auto output = make_mobilenet_v2(network, 1);
HostTensorND t1;
auto func1 = network.graph->compile({make_callback_copy(output, t1)});
func1->execute();
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Target = LayoutTransformContext::Target;
using Attribute = LayoutTransformContext::Attribute;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::Concat::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::WarpPerspectiveForward::typeinfo(),
opr::Resize::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW,
TensorFormats::NCHWc4,
TensorFormats::NCHWc8,
};
Attribute attribute = {
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{
OprFormatConfigID::NCHW88,
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW88_HYBRID,
})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{
OprFormatConfigID::NCHW88,
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW88_HYBRID,
})
.add_opr_config(
opr::PoolingForward::typeinfo(), {
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW88,
});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_MobileNetV2_NCHW88.data()),
TestLayoutTransform_MobileNetV2_NCHW88.size());
#else
auto profiler = ProfilerBase::make_cached_profiler(
"TestLayoutTransform.MobileNetV2_NCHW88.cache");
#endif
std::unique_ptr<SolverBase> solver{
new DynamicProgrammingSolver(std::move(profiler))};
auto new_output =
gopt::GraphOptimizer{}
.add_pass<FuseConvBiasNonlinPass>()
.add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
.add_pass<ShuffleShuffleRemovePass>()
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply({{output}})
.endpoint_vars();
auto new_out_var = new_output[0];
/// check global layout transform pass
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
ASSERT_EQ(nr_dimshuffle, 1u);
/// check first conv format
const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW88);
GraphProfiler gprof{network.graph.get()};
HostTensorND t2;
auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
func2->execute();
gprof.to_json_full(func2.get())
->writeto_fpath(output_file("mobilenet_v2_nchw88.json"));
/// check correct
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
TEST(TestLayoutTransform, MobileNetV2_NCHW44_DOT) {
auto cn = CompNode::load("cpu0");
Network network(cn);
auto output = make_mobilenet_v2(network, 1, dtype::QuantizedS8{1.f});
HostTensorND t1;
auto func1 = network.graph->compile({make_callback_copy(output, t1)});
func1->execute();
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Target = LayoutTransformContext::Target;
using Attribute = LayoutTransformContext::Attribute;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
opr::ConvolutionForward::typeinfo(),
opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(),
opr::TypeCvt::typeinfo(),
opr::Concat::typeinfo(),
opr::PoolingForward::typeinfo(),
opr::WarpPerspectiveForward::typeinfo(),
opr::Resize::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW,
TensorFormats::NCHWc4,
TensorFormats::NCHWc8,
};
Attribute attribute = {
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW44_HYBRID,
OprFormatConfigID::NCHW44_DOT,
OprFormatConfigID::NCHW44_DOT_HYBRID,
})
.add_opr_config(
opr::ConvolutionForward::typeinfo(),
{
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44,
OprFormatConfigID::NCHW44_HYBRID,
OprFormatConfigID::NCHW44_DOT,
OprFormatConfigID::NCHW44_DOT_HYBRID,
})
.add_opr_config(
opr::PoolingForward::typeinfo(), {
OprFormatConfigID::NCHW,
OprFormatConfigID::NCHW44,
});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(
TestLayoutTransform_MobileNetV2_NCHW44_DOT.data()),
TestLayoutTransform_MobileNetV2_NCHW44_DOT.size());
#else
auto profiler = ProfilerBase::make_cached_profiler(
"TestLayoutTransform.MobileNetV2_NCHW44_DOT.cache");
#endif
std::unique_ptr<SolverBase> solver{
new DynamicProgrammingSolver(std::move(profiler))};
auto new_output =
gopt::GraphOptimizer{}
.add_pass<FuseConvBiasNonlinPass>()
.add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
.add_pass<ShuffleShuffleRemovePass>()
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply({{output}})
.endpoint_vars();
auto new_out_var = new_output[0];
/// check global layout transform pass
auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
ASSERT_EQ(nr_dimshuffle, 1u);
/// check first conv format
const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW44_DOT);
GraphProfiler gprof{network.graph.get()};
HostTensorND t2;
auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
func2->execute();
gprof.to_json_full(func2.get())
->writeto_fpath(output_file("mobilenet_v2_nchw44_dot.json"));
/// check correct
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -57,7 +57,10 @@ SymbolVar Network::add_group_conv( ...@@ -57,7 +57,10 @@ SymbolVar Network::add_group_conv(
{groups, output_channels / groups, input_channels / groups, kern_size[0], {groups, output_channels / groups, input_channels / groups, kern_size[0],
kern_size[1]}); kern_size[1]});
auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1}); auto bias = add_cvar(ssprintf("b%d", bias_idx).c_str(), {1, output_channels, 1, 1});
mgb_assert(out_dtype.category() == DTypeCategory::FLOAT); if (out_dtype.category() == DTypeCategory::QUANTIZED) {
weight = add_type_cvt(weight, out_dtype);
bias = add_type_cvt(bias, dtype::QuantizedS32{1.f});
}
opr::ConvBias::Param param; opr::ConvBias::Param param;
param.sparse = opr::ConvBias::Param::Sparse::GROUP; param.sparse = opr::ConvBias::Param::Sparse::GROUP;
param.stride_h = stride[0], param.stride_w = stride[1]; param.stride_h = stride[0], param.stride_w = stride[1];
...@@ -68,8 +71,15 @@ SymbolVar Network::add_group_conv( ...@@ -68,8 +71,15 @@ SymbolVar Network::add_group_conv(
param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY; param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
} }
auto conv = opr::ConvBias::make( weight_idx++;
bias_idx++;
SymbolVar conv;
if (out_dtype.category() == DTypeCategory::QUANTIZED) {
conv = opr::ConvBias::make(
f, weight, bias, param, {}, OperatorNodeConfig{out_dtype}); f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
} else {
conv = opr::ConvBias::make(f, weight, bias, param, {});
}
weight_idx++; weight_idx++;
bias_idx++; bias_idx++;
return conv; return conv;
...@@ -269,17 +279,17 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) { ...@@ -269,17 +279,17 @@ SymbolVarArray mgb::make_det(Network& network, size_t batch, DType out_dtype) {
SymbolVar mgb::bottleneck( SymbolVar mgb::bottleneck(
Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t, Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
size_t stride) { size_t stride, DType out_dtype) {
size_t in_channels = f.node()->shape()[1]; size_t in_channels = f.node()->shape()[1];
SymbolVar x = f; SymbolVar x = f;
if (t != 1) { if (t != 1) {
x = network.add_conv( x = network.add_conv(
f, input_channels * t, {1, 1}, dtype::Float32(), true, {1, 1}, {0, 0}); f, input_channels * t, {1, 1}, out_dtype, true, {1, 1}, {0, 0});
} }
x = network.add_group_conv( x = network.add_group_conv(
x, input_channels * t, input_channels * t, {3, 3}, dtype::Float32(), true, x, input_channels * t, input_channels * t, {3, 3}, out_dtype, true,
{stride, stride}, {1, 1}); {stride, stride}, {1, 1});
x = network.add_conv(x, channels, {1, 1}, dtype::Float32(), false, {1, 1}, {0, 0}); x = network.add_conv(x, channels, {1, 1}, out_dtype, false, {1, 1}, {0, 0});
if (stride == 1 && in_channels == channels) if (stride == 1 && in_channels == channels)
x = f + x; x = f + x;
return x; return x;
...@@ -287,11 +297,11 @@ SymbolVar mgb::bottleneck( ...@@ -287,11 +297,11 @@ SymbolVar mgb::bottleneck(
SymbolVar mgb::bottleneck_group( SymbolVar mgb::bottleneck_group(
Network& network, SymbolVar f, size_t input_channels, size_t channels, Network& network, SymbolVar f, size_t input_channels, size_t channels,
size_t stages, size_t s, size_t t) { size_t stages, size_t s, size_t t, DType out_dtype) {
SymbolVar x = f; SymbolVar x = f;
for (size_t i = 0; i < stages; ++i) { for (size_t i = 0; i < stages; ++i) {
size_t stride = i == 0 ? s : 1; size_t stride = i == 0 ? s : 1;
x = bottleneck(network, x, input_channels, channels, t, stride); x = bottleneck(network, x, input_channels, channels, t, stride, out_dtype);
input_channels = channels; input_channels = channels;
} }
return x; return x;
...@@ -307,22 +317,34 @@ size_t make_divisible(size_t v, size_t divisor) { ...@@ -307,22 +317,34 @@ size_t make_divisible(size_t v, size_t divisor) {
} }
} // namespace } // namespace
SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch) { SymbolVar mgb::make_mobilenet_v2(Network& network, size_t batch, DType out_dtype) {
auto data = network.add_var("data", {batch, 3, 224, 224}); auto data = network.add_var("data", {batch, 3, 224, 224});
if (out_dtype.category() == DTypeCategory::QUANTIZED) {
data = network.add_type_cvt(data, dtype::QuantizedS8{1.f});
}
constexpr size_t round_nearest = 8; constexpr size_t round_nearest = 8;
auto x = network.add_conv( auto x = network.add_conv(
data, make_divisible(32, round_nearest), {3, 3}, dtype::Float32(), true, data, make_divisible(32, round_nearest), {3, 3}, out_dtype, true, {2, 2},
{2, 2}, {1, 1}); {1, 1});
x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1); x = bottleneck(network, x, 32, make_divisible(16, round_nearest), 1, 1, out_dtype);
x = bottleneck_group(network, x, 16, make_divisible(24, round_nearest), 2, 2, 6); x = bottleneck_group(
x = bottleneck_group(network, x, 24, make_divisible(32, round_nearest), 3, 2, 6); network, x, 16, make_divisible(24, round_nearest), 2, 2, 6, out_dtype);
x = bottleneck_group(network, x, 32, make_divisible(64, round_nearest), 4, 2, 6); x = bottleneck_group(
x = bottleneck_group(network, x, 64, make_divisible(96, round_nearest), 3, 1, 6); network, x, 24, make_divisible(32, round_nearest), 3, 2, 6, out_dtype);
x = bottleneck_group(network, x, 96, make_divisible(160, round_nearest), 3, 2, 6); x = bottleneck_group(
x = bottleneck_group(network, x, 160, make_divisible(320, round_nearest), 1, 1, 6); network, x, 32, make_divisible(64, round_nearest), 4, 2, 6, out_dtype);
x = bottleneck_group(
network, x, 64, make_divisible(96, round_nearest), 3, 1, 6, out_dtype);
x = bottleneck_group(
network, x, 96, make_divisible(160, round_nearest), 3, 2, 6, out_dtype);
x = bottleneck_group(
network, x, 160, make_divisible(320, round_nearest), 1, 1, 6, out_dtype);
x = network.add_conv( x = network.add_conv(
x, make_divisible(1280, round_nearest), {1, 1}, dtype::Float32(), true, x, make_divisible(1280, round_nearest), {1, 1}, out_dtype, true, {1, 1},
{1, 1}, {0, 0}); {0, 0});
if (out_dtype.category() == DTypeCategory::QUANTIZED) {
x = network.add_type_cvt(x, dtype::Float32());
}
return x; return x;
} }
......
...@@ -79,13 +79,14 @@ SymbolVarArray make_det( ...@@ -79,13 +79,14 @@ SymbolVarArray make_det(
SymbolVar bottleneck( SymbolVar bottleneck(
Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t, Network& network, SymbolVar f, size_t input_channels, size_t channels, size_t t,
size_t stride); size_t stride, DType out_dtype = dtype::Float32());
SymbolVar bottleneck_group( SymbolVar bottleneck_group(
Network& network, SymbolVar f, size_t input_channels, size_t channels, Network& network, SymbolVar f, size_t input_channels, size_t channels,
size_t stages, size_t s, size_t t); size_t stages, size_t s, size_t t, DType out_dtype = dtype::Float32());
SymbolVar make_mobilenet_v2(Network& network, size_t batch = 1); SymbolVar make_mobilenet_v2(
Network& network, size_t batch = 1, DType out_dtype = dtype::Float32());
} // namespace mgb } // namespace mgb
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册