提交 eb18eba8 编写于 作者: M Megvii Engine Team 提交者: Xinran Xu

fix(gopt): fix nchw44 nchw44_dot gopt test

GitOrigin-RevId: 06b38dcd30249d363e3861dc8813a26bb20bf70f
上级 40e79e9d
...@@ -10,12 +10,12 @@ ...@@ -10,12 +10,12 @@
* implied. * implied.
*/ */
#include "src/fallback/convolution/opr_impl.h"
#include "src/common/algo_chooser.h" #include "src/common/algo_chooser.h"
#include "src/common/metahelper.h" #include "src/common/metahelper.h"
#include "src/common/opr_delegate.h" #include "src/common/opr_delegate.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/fallback/convolution/algos.h" #include "src/fallback/convolution/algos.h"
#include "src/fallback/convolution/opr_impl.h"
#include "src/fallback/convolution/run_conv.h" #include "src/fallback/convolution/run_conv.h"
#include "src/naive/convolution/helper.h" #include "src/naive/convolution/helper.h"
#include "src/naive/handle.h" #include "src/naive/handle.h"
...@@ -100,10 +100,10 @@ void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter, ...@@ -100,10 +100,10 @@ void ConvolutionImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_in filter,
} }
void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout, void ConvolutionImpl::exec_preprocess(const TensorLayout& src_layout,
_megdnn_tensor_in filter, _megdnn_tensor_in filter,
const TensorLayout& dst_layout, const TensorLayout& dst_layout,
PreprocessedFilter* preprocessed_filter, PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) { _megdnn_workspace workspace) {
//! exec_preprocess currently only support preprocess weights before exec, //! exec_preprocess currently only support preprocess weights before exec,
//! src/dst will be ignored, just set to nullptr //! src/dst will be ignored, just set to nullptr
TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout}; TensorND src{nullptr, src_layout}, dst{nullptr, dst_layout};
...@@ -151,7 +151,7 @@ size_t ConvolutionImpl::get_preprocess_workspace_in_bytes( ...@@ -151,7 +151,7 @@ size_t ConvolutionImpl::get_preprocess_workspace_in_bytes(
SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout( SmallVector<TensorLayout> ConvolutionImpl::deduce_preprocessed_filter_layout(
const TensorLayout& src, const TensorLayout& filter, const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst){ const TensorLayout& dst) {
auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr); auto fparam = make_ncb_kern_size_param(src, filter, dst, nullptr);
Algorithm* algo = get_algorithm(fparam); Algorithm* algo = get_algorithm(fparam);
if (is_naive_algo(algo)) { if (is_naive_algo(algo)) {
...@@ -257,7 +257,8 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param, ...@@ -257,7 +257,8 @@ void ConvolutionImpl::exec_preprocess_with_ncb_kern(const NCBKernParam& param,
param.filter_meta.format == Param::Format::NCHW || param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC || param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format == Param::Format::NCHW88 || param.filter_meta.format == Param::Format::NCHW88 ||
param.filter_meta.format == Param::Format::NCHW44, param.filter_meta.format == Param::Format::NCHW44 ||
param.filter_meta.format == Param::Format::NCHW44_DOT,
"invalid conv format"); "invalid conv format");
auto run = [param, kernel](size_t index, size_t thread_id) { auto run = [param, kernel](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index); CpuNDRange ndrange_id(kernel.global_size, index);
...@@ -277,7 +278,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param, ...@@ -277,7 +278,8 @@ void ConvolutionImpl::exec_with_ncb_kern(const NCBKernParam& param,
param.filter_meta.format == Param::Format::NCHW || param.filter_meta.format == Param::Format::NCHW ||
param.filter_meta.format == Param::Format::NHWC || param.filter_meta.format == Param::Format::NHWC ||
param.filter_meta.format == Param::Format::NCHW88 || param.filter_meta.format == Param::Format::NCHW88 ||
param.filter_meta.format == Param::Format::NCHW44, param.filter_meta.format == Param::Format::NCHW44 ||
param.filter_meta.format == Param::Format::NCHW44_DOT,
"invalid conv format"); "invalid conv format");
auto run = [param, kernel](size_t index, size_t thread_id) { auto run = [param, kernel](size_t index, size_t thread_id) {
CpuNDRange ndrange_id(kernel.global_size, index); CpuNDRange ndrange_id(kernel.global_size, index);
......
...@@ -31,6 +31,8 @@ ...@@ -31,6 +31,8 @@
#include "megdnn/opr_param_defs.h" #include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h" #include "megdnn/tensor_format.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#if MGB_ENABLE_TENSOR_RT #if MGB_ENABLE_TENSOR_RT
#include "megbrain/tensorrt/tensorrt_opr.h" #include "megbrain/tensorrt/tensorrt_opr.h"
#endif #endif
...@@ -392,7 +394,8 @@ void TensorReformatPass::insert_pass(OptState& opt) const { ...@@ -392,7 +394,8 @@ void TensorReformatPass::insert_pass(OptState& opt) const {
auto new_opr = (it->second)(opr, new_inp); auto new_opr = (it->second)(opr, new_inp);
auto &&out0 = opr->output(), &&out1 = new_opr->output(); auto &&out0 = opr->output(), &&out1 = new_opr->output();
mgb_assert(out0.size() == out1.size(), mgb_assert(out0.size() == out1.size(),
"bad opr replace: src=%s{%s} dst=%s{%s}, src.size=%zu " "bad opr replace: src=%s{%s} dst=%s{%s}, "
"src.size=%zu "
"dst.size=%zu", "dst.size=%zu",
opr->cname(), opr->dyn_typeinfo()->name, opr->cname(), opr->dyn_typeinfo()->name,
new_opr->cname(), new_opr->dyn_typeinfo()->name, new_opr->cname(), new_opr->dyn_typeinfo()->name,
...@@ -1811,14 +1814,156 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var, ...@@ -1811,14 +1814,156 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
} }
return new_var; return new_var;
} }
//! nchw_nchwxx_valid is used to indicate optimized nchw_nchw44 conv
static inline bool nchw_nchwxx_valid(const size_t oc, const size_t ic, template <typename OprType>
const size_t pack_c_size, const size_t fh, static inline bool nchw_nchwxx_valid(const OprType& opr,
const size_t fw, const size_t stride_h, const VarNodeArray& new_inp,
const size_t stride_w) { const size_t pack_size, bool is_dense,
return ic < pack_c_size && oc % pack_c_size == 0 && fh == fw && bool is_dot = false);
stride_h == stride_w && (stride_h == 1 || stride_h == 2) && template <>
(fh == 2 || fh == 3 || fh == 5 || fh == 7); inline bool nchw_nchwxx_valid<opr::ConvolutionForward>(
const opr::ConvolutionForward& opr, const VarNodeArray& new_inp,
const size_t pack_size, bool is_dense, bool is_dot) {
auto& filter_shape = new_inp[1]->shape();
auto filter_dtype = new_inp[1]->dtype();
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
filter_dtype.enumv() == DTypeEnum::Int8;
const size_t oc = filter_shape[0];
const size_t ic = filter_shape[1];
bool is_like_nchw_nchwxx =
is_dense && oc % pack_size == 0 && ic < pack_size;
if (!is_like_nchw_nchwxx) {
return false;
}
SmallVector<TensorLayout> layouts;
//! src
layouts.push_back(
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
//! weight
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
filter_shape[3], filter_shape[1], pack_size},
new_inp[1]->dtype(),
new_inp[1]->format()});
auto out0 = opr.output(0);
auto& out_shape = out0->shape();
//! FIXME: return false if oc is invalid
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
out_shape[3], pack_size},
out0->dtype(),
out0->format()});
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
->create_operator<megdnn::ConvolutionForward>();
megdnn_conv.get()->param() = opr.param();
//! set by dtype
switch (pack_size) {
case 4:
if (is_dot && is_int8) {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44_DOT;
} else {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44;
}
break;
case 8:
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW88;
break;
default:
break;
}
bool find_valid_algo = false;
auto algos = megdnn_conv.get()->get_all_algorithms(layouts[0], layouts[1],
layouts[2]);
for (auto i : algos) {
if (i->type() != nullptr) {
find_valid_algo = true;
}
}
return find_valid_algo;
}
template <>
inline bool nchw_nchwxx_valid<opr::ConvBiasForward>(
const opr::ConvBiasForward& opr, const VarNodeArray& new_inp,
const size_t pack_size, bool is_dense, bool is_dot) {
auto& filter_shape = new_inp[1]->shape();
auto filter_dtype = new_inp[1]->dtype();
bool is_int8 = filter_dtype.enumv() == DTypeEnum::QuantizedS8 ||
filter_dtype.enumv() == DTypeEnum::Int8;
const size_t oc = filter_shape[0];
const size_t ic = filter_shape[1];
bool is_like_nchw_nchwxx =
is_dense && oc % pack_size == 0 && ic < pack_size;
if (!is_like_nchw_nchwxx) {
return false;
}
SmallVector<TensorLayout> layouts;
//! src
layouts.push_back(
{new_inp[0]->shape(), new_inp[0]->dtype(), new_inp[0]->format()});
//! weight
layouts.push_back({{filter_shape[0] / pack_size, filter_shape[2],
filter_shape[3], filter_shape[1], pack_size},
new_inp[1]->dtype(),
new_inp[1]->format()});
auto& bias_shape = new_inp[2]->shape();
layouts.push_back({{bias_shape[0], bias_shape[1] / pack_size, bias_shape[2],
bias_shape[3], pack_size},
new_inp[2]->dtype(),
new_inp[2]->format()});
auto out0 = opr.output(0);
auto& out_shape = out0->shape();
//! FIXME: return false if oc is invalid
layouts.push_back({{out_shape[0], out_shape[1] / pack_size, out_shape[2],
out_shape[3], pack_size},
out0->dtype(),
out0->format()});
// megdnn::ConvolutionForward
auto megdnn_conv = opr::intl::get_megdnn_handle(opr.comp_node())
->create_operator<megdnn::ConvBiasForward>();
megdnn_conv.get()->param() = opr.param();
//! FIXME: set by dtype
switch (pack_size) {
case 4:
if (is_dot && is_int8) {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44_DOT;
} else {
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW44;
}
break;
case 8:
megdnn_conv.get()->param().format =
megdnn::param::Convolution::Format::NCHW88;
break;
default:
break;
}
bool find_valid_algo = false;
auto algos = megdnn_conv.get()->get_all_algorithms(
layouts[0], layouts[1], layouts[2], {}, layouts[3]);
for (auto i : algos) {
if (i->type() != nullptr) {
find_valid_algo = true;
}
}
return find_valid_algo;
} }
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
using RelayoutMode = RelayoutPlaceholder::LayoutType; using RelayoutMode = RelayoutPlaceholder::LayoutType;
...@@ -1839,6 +1984,20 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -1839,6 +1984,20 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
megdnn::param::Pooling::Format pooling_format = megdnn::param::Pooling::Format pooling_format =
megdnn::param::Pooling::Format::NCHW88; megdnn::param::Pooling::Format::NCHW88;
std::string convter_pass_name = "conv_format_nchw88"; std::string convter_pass_name = "conv_format_nchw88";
#if MEGDNN_AARCH64 || MEGDNN_ARMv7
if (pack_c_size == 8) {
mgb_log_error(
"runtime backend is ARM, but nchw88 only support X86, you may "
"have performance loss\n");
}
#elif MEGDNN_X86
if (pack_c_size == 4) {
mgb_log_error(
"runtime backend is X86, but nchw44 only support arm, you may "
"have performance loss\n");
}
#endif
if (pack_c_size == 4) { if (pack_c_size == 4) {
weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE; weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP; weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
...@@ -1857,18 +2016,16 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -1857,18 +2016,16 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
hybrid_nchw_nchwxx]( hybrid_nchw_nchwxx](
const megdnn::param::Convolution::Sparse conv_mode, const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter, const size_t stride_h, const VarNode* filter, const size_t stride_h,
const size_t stride_w) -> TestFilterResult { const size_t stride_w,
bool valid_nchw_nchw44) -> TestFilterResult {
TestFilterResult ret{TransType::TRANS_NONE, {}}; TestFilterResult ret{TransType::TRANS_NONE, {}};
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t OC = filter->shape()[0]; size_t OC = filter->shape()[0];
size_t IC = filter->shape()[1]; size_t IC = filter->shape()[1];
size_t FH = filter->shape()[2];
size_t FW = filter->shape()[3];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.first = TransType::TRANS_PURE_NCHWXX; ret.first = TransType::TRANS_PURE_NCHWXX;
ret.second = weight_to_nchwxx_mode_dense; ret.second = weight_to_nchwxx_mode_dense;
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, } else if (valid_nchw_nchw44) {
stride_w)) {
ret.first = TransType::TRANS_HYBIRD_NCHWXX; ret.first = TransType::TRANS_HYBIRD_NCHWXX;
ret.second = hybrid_nchw_nchwxx; ret.second = hybrid_nchw_nchwxx;
} }
...@@ -1888,16 +2045,21 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -1888,16 +2045,21 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
return ret; return ret;
}; };
auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode, auto replace_conv_opr = [test_trans_nchwxx, conv_format, src_to_nchwxx_mode,
src_to_nchw_mode](OperatorNodeBase* opr, src_to_nchw_mode,
const VarNodeArray& new_inp) { pack_c_size](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>(); auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
mgb_assert(conv_opr.param().format == mgb_assert(conv_opr.param().format ==
megdnn::param::Convolution::Format::NCHW, megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX"); "ConvertFormat Pass only support converting NCHW to NCHWXX");
auto is_trans = test_trans_nchwxx(conv_opr.param().sparse, new_inp[1], bool is_dense = conv_opr.param().sparse ==
conv_opr.param().stride_h, megdnn::param::Convolution::Sparse::DENSE;
conv_opr.param().stride_w); bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
auto is_trans = test_trans_nchwxx(
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
conv_opr.param().stride_w, valid_nchw_nchw44);
//! can not trans to nchwxx //! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) { if (is_trans.first == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
...@@ -1963,17 +2125,23 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) { ...@@ -1963,17 +2125,23 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
}; };
auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format, auto replace_conv_bias_opr = [test_trans_nchwxx, conv_bias_format,
src_to_nchwxx_mode, src_to_nchw_mode]( src_to_nchwxx_mode, src_to_nchw_mode,
OperatorNodeBase* opr, pack_c_size](OperatorNodeBase* opr,
const VarNodeArray& new_inp) { const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size()); mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>(); auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
mgb_assert(conv_bias_opr.param().format == mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW, megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX"); "ConvertFormat Pass only support converting NCHW to NCHWXX");
bool is_dense = conv_bias_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
pack_c_size, is_dense);
auto is_trans = test_trans_nchwxx( auto is_trans = test_trans_nchwxx(
conv_bias_opr.param().sparse, new_inp[1], conv_bias_opr.param().sparse, new_inp[1],
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
valid_nchw_nchw44);
//! can not trans to nchwxx //! can not trans to nchwxx
if (is_trans.first == TransType::TRANS_NONE) { if (is_trans.first == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
...@@ -2203,8 +2371,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2203,8 +2371,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
MIDOUT_B("EnableNchw44DotPass::make") MIDOUT_B("EnableNchw44DotPass::make")
auto ret = std::make_unique<EnableNchw44DotPass>(); auto ret = std::make_unique<EnableNchw44DotPass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK); ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
//! First is whether the conv can trans to nchwxx, second is the filter //! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode //! trans mode
#if MEGDNN_X86
mgb_log_error(
"backend is X86, but nchw44_dot only support arm, you may have "
"performance loss\n");
#endif
using RelayoutMode = RelayoutPlaceholder::LayoutType; using RelayoutMode = RelayoutPlaceholder::LayoutType;
struct TestTransResult { struct TestTransResult {
TransType trans_type; TransType trans_type;
...@@ -2215,23 +2389,35 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2215,23 +2389,35 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
auto test_trans_nchw44_dot = auto test_trans_nchw44_dot =
[](const megdnn::param::Convolution::Sparse conv_mode, [](const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter, const size_t stride_h, const VarNode* filter, const size_t stride_h,
const size_t stride_w) -> TestTransResult { const size_t stride_w,
const bool valid_nchw_nchw44) -> TestTransResult {
TestTransResult ret{TransType::TRANS_NONE, {}, {}}; TestTransResult ret{TransType::TRANS_NONE, {}, {}};
bool is_int8 = filter->dtype().enumv() == DTypeEnum::QuantizedS8 ||
filter->dtype().enumv() == DTypeEnum::Int8;
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) { if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
size_t OC = filter->shape()[0]; size_t OC = filter->shape()[0];
size_t IC = filter->shape()[1]; size_t IC = filter->shape()[1];
size_t FH = filter->shape()[2];
size_t FW = filter->shape()[3];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) { if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX; ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = if (is_int8) {
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE; ret.relayout_mod =
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
} else if (nchw_nchwxx_valid(OC, IC, pack_c_size, FH, FW, stride_h, ret.conv_format =
stride_w)) { megdnn::param::ConvBias::Format::NCHW44_DOT;
} else {
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
}
} else if (valid_nchw_nchw44) {
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX; ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44; ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; if (is_int8) {
ret.conv_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
} else {
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
}
} }
} else { } else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP); mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
...@@ -2244,9 +2430,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2244,9 +2430,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44; ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) { } else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX; ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = if (is_int8) {
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP; ret.relayout_mod =
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT; RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.conv_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
} else {
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
}
} }
} }
return ret; return ret;
...@@ -2260,9 +2453,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2260,9 +2453,14 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
megdnn::param::Convolution::Format::NCHW, megdnn::param::Convolution::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to " "ConvertFormat Pass only support converting NCHW to "
"NCHW44_DOT"); "NCHW44_DOT");
bool is_dense = conv_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 =
nchw_nchwxx_valid(conv_opr, new_inp, pack_c_size, is_dense);
auto is_trans = test_trans_nchw44_dot( auto is_trans = test_trans_nchw44_dot(
conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h, conv_opr.param().sparse, new_inp[1], conv_opr.param().stride_h,
conv_opr.param().stride_w); conv_opr.param().stride_w, valid_nchw_nchw44);
//! can not trans to nchwxx //! can not trans to nchwxx
if (is_trans.trans_type == TransType::TRANS_NONE) { if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
...@@ -2335,9 +2533,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2335,9 +2533,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
mgb_assert(conv_bias_opr.param().format == mgb_assert(conv_bias_opr.param().format ==
megdnn::param::ConvBias::Format::NCHW, megdnn::param::ConvBias::Format::NCHW,
"ConvertFormat Pass only support converting NCHW to NCHWXX"); "ConvertFormat Pass only support converting NCHW to NCHWXX");
bool is_dense = conv_bias_opr.param().sparse ==
megdnn::param::Convolution::Sparse::DENSE;
bool valid_nchw_nchw44 = nchw_nchwxx_valid(conv_bias_opr, new_inp,
pack_c_size, is_dense);
auto is_trans = test_trans_nchw44_dot( auto is_trans = test_trans_nchw44_dot(
conv_bias_opr.param().sparse, new_inp[1], conv_bias_opr.param().sparse, new_inp[1],
conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w); conv_bias_opr.param().stride_h, conv_bias_opr.param().stride_w,
valid_nchw_nchw44);
auto megdnn_conv =
opr::intl::get_megdnn_handle(conv_bias_opr.comp_node())
->create_operator<megdnn::ConvBiasForward>();
SmallVector<TensorLayout> layouts;
//! can not trans to nchwxx //! can not trans to nchwxx
if (is_trans.trans_type == TransType::TRANS_NONE) { if (is_trans.trans_type == TransType::TRANS_NONE) {
mgb_assert(new_inp[1]->shape().ndim == 4 || mgb_assert(new_inp[1]->shape().ndim == 4 ||
...@@ -2350,6 +2558,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() { ...@@ -2350,6 +2558,7 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
new_inp[0], RelayoutMode::NCHW4_TO_NCHW); new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
temp_inp[0] = new_src.node(); temp_inp[0] = new_src.node();
} }
//! the bias is nchwxx //! the bias is nchwxx
if (temp_inp[2]->shape().ndim == 5) { if (temp_inp[2]->shape().ndim == 5) {
auto new_bias = RelayoutPlaceholder::make( auto new_bias = RelayoutPlaceholder::make(
......
...@@ -32,6 +32,7 @@ ...@@ -32,6 +32,7 @@
#include "./helper.h" #include "./helper.h"
#include "megbrain/comp_node_env.h" #include "megbrain/comp_node_env.h"
#include "cpuinfo.h"
#include "megdnn/tensor_format.h" #include "megdnn/tensor_format.h"
#include <random> #include <random>
...@@ -49,7 +50,21 @@ T& find_opr(SymbolVar endpoint) { ...@@ -49,7 +50,21 @@ T& find_opr(SymbolVar endpoint) {
} }
}; };
cg::DepOprIter{cb}.add(endpoint.node()->owner_opr()); cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
mgb_assert(found); mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str());
return *found;
}
//! Among the operators visited by cg::DepOprIter starting at the owner
//! operator of \p endpoint, return the first one that has type T and whose
//! name equals \p node_name. Asserts if no such operator was visited.
template <typename T>
T& find_opr(SymbolVar endpoint, const std::string& node_name) {
    T* found = nullptr;
    auto match_by_name = [&found, &node_name](cg::OperatorNodeBase* opr) {
        //! keep the first hit; later operators are ignored
        if (found) {
            return;
        }
        if (!opr->same_type<T>()) {
            return;
        }
        if (opr->name() == node_name) {
            found = &opr->cast_final_safe<T>();
        }
    };
    cg::DepOprIter{match_by_name}.add(endpoint.node()->owner_opr());
    mgb_assert(found, "not found opr %s from %s", node_name.c_str(),
               endpoint.node()->name().c_str());
    return *found;
}
...@@ -2973,25 +2988,48 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { ...@@ -2973,25 +2988,48 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name); .rename(name);
}; };
auto mkcvar_dtype = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto host_x = gen({2, 3, 16, 16}, cn); auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//! Hybrid nchw44 mode //! Hybrid nchw44 mode
opr::Convolution::Param param_conv; opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1; param_conv.pad_h = param_conv.pad_w = 1;
opr::ConvBias::Param param_conv_bias_stride4;
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
auto w1 = mkcvar("w1", {8, 3, 3, 3}), auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv); conv1 = opr::Convolution::make(x, w1, param_conv, {},
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), OperatorNodeConfig("conv1"));
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv);
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); //! no supported hybrid nchw44
//! channel wise opr::ConvBias::Param param_conv_bias_pad0;
param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
auto b1 = mkcvar("b1", {1, 8, 1, 1});
auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
OperatorNodeConfig("conv1_f1"));
auto conv1_add = conv1_f1 * conv1;
auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f));
//! s8 dense conv
opr::ConvBias::Param param_conv_bias; opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv_1_2 = opr::ConvBias::make(
conv_1_q8, w1_2, b1_2, param_conv_bias, {},
OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());
//! channel wise
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}), auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias); conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias);
//! group //! group
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
...@@ -3013,42 +3051,67 @@ TEST(TestGoptInference, ConvertFormatNCHW44) { ...@@ -3013,42 +3051,67 @@ TEST(TestGoptInference, ConvertFormatNCHW44) {
//! Dense //! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}),
conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); b3_2 = mkcvar("b3_2", {1, 16, 1, 1}),
conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {},
OperatorNodeConfig("conv3_2"));
//! s8 group conv
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f));
auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)),
conv3_3_q = opr::ConvBias::make(
conv3_2_q8, w3_3, b3_3, param_conv_bias, {},
OperatorNodeConfig{"conv_3_3_q", cn,
dtype::QuantizedS8{6.25f}});
auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32());
//! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {},
OperatorNodeConfig("conv4"));
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {},
OperatorNodeConfig("conv5"));
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {},
OperatorNodeConfig("conv6"));
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw44(); options.enable_nchw44();
unpack_vector(gopt::optimize_for_inference( unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
{y, conv1, conv1_f4, conv1_s4, conv2}, options),
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt, conv2_opt);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::Convolution>(conv1_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::Convolution>(conv1_f4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::ConvBias>(conv2_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}, {conv2, {}}}) #if MEGDNN_AARCH64 || MEGDNN_ARMV7
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
#else
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
#endif
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt, "conv_1_2").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt, "conv3_2").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt, "conv_3_3_q").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt, "conv4").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt, "conv5").param().format);
graph->compile({{y_opt, {}}})
->to_json() ->to_json()
->writeto_fpath( ->writeto_fpath(
output_file("TestGoptInference.ConvertFormatNCHW44.json")); output_file("TestGoptInference.ConvertFormatNCHW44.json"));
HostTensorND host_y_opt, host_y; HostTensorND host_y_opt, host_y;
HostTensorND host_conv1;
auto func = graph->compile({make_callback_copy(y, host_y), auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt), make_callback_copy(y_opt, host_y_opt)});
make_callback_copy(conv1, host_conv1)});
func->execute(); func->execute();
//! meybe go to winograd in x86-32, so set error 1e-1 //! meybe go to winograd in x86-32, so set error 1e-1
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1); MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-1);
...@@ -3155,25 +3218,58 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { ...@@ -3155,25 +3218,58 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)) return opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name); .rename(name);
}; };
auto mkcvar_dtype = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn))
.rename(name),
dtype);
};
auto host_x = gen({2, 3, 16, 16}, cn); auto host_x = gen({2, 3, 16, 16}, cn);
auto x = opr::Host2DeviceCopy::make(*graph, host_x); auto x = opr::Host2DeviceCopy::make(*graph, host_x);
//! Hybrid nchw44 mode //! Hybrid nchw44 mode
opr::Convolution::Param param_conv; opr::Convolution::Param param_conv;
param_conv.pad_h = param_conv.pad_w = 1; param_conv.pad_h = param_conv.pad_w = 1;
opr::ConvBias::Param param_conv_bias_stride4;
param_conv_bias_stride4.stride_h = param_conv_bias_stride4.stride_w = 4;
auto w1 = mkcvar("w1", {8, 3, 3, 3}), auto w1 = mkcvar("w1", {8, 3, 3, 3}),
conv1 = opr::Convolution::make(x, w1, param_conv); conv1 = opr::Convolution::make(x, w1, param_conv, {},
auto w1_1 = mkcvar("w1_1", {8, 3, 4, 4}), b1 = mkcvar("b2", {1, 8, 1, 1}), OperatorNodeConfig("conv1"));
conv1_f4 = opr::Convolution::make(x, w1_1, param_conv); printf("create conv1 %s\n",
auto conv1_s4 = opr::ConvBias::make(x, w1, b1, param_conv_bias_stride4); conv1.node()->owner_opr()->dyn_typeinfo()->name);
//! channel wise param_conv.pad_h = param_conv.pad_w = 1;
//! no supported hybrid nchw44
opr::ConvBias::Param param_conv_bias_pad0;
param_conv_bias_pad0.pad_h = param_conv_bias_pad0.pad_w = 0;
auto b1 = mkcvar("b1", {1, 8, 1, 1});
auto w1_f1 = mkcvar("w1_1", {8, 3, 1, 1});
auto conv1_f1 = opr::ConvBias::make(x, w1_f1, b1, param_conv_bias_pad0, {},
OperatorNodeConfig("conv1_f1"));
//! hybrid dot
auto x_s = opr::TypeCvt::make(x, dtype::QuantizedS8(2.5f));
auto w1_3 = mkcvar_dtype("w1_3", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f));
auto conv1_3_q = opr::Convolution::make(
x_s, w1_3, param_conv, {},
OperatorNodeConfig{"conv1_3_q", cn, dtype::QuantizedS8{6.25f}});
auto conv1_3 = opr::TypeCvt::make(conv1_3_q, dtype::Float32());
auto conv1_add = conv1_f1 * conv1 * conv1_3;
auto conv_1_q8 = opr::TypeCvt::make(conv1_add, dtype::QuantizedS8(2.5f));
//! s8 dense conv
opr::ConvBias::Param param_conv_bias; opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w1_2 = mkcvar_dtype("w1_2", {8, 8, 3, 3}, dtype::QuantizedS8(2.5f));
auto b1_2 = mkcvar_dtype("b1_2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv_1_2 = opr::ConvBias::make(
conv_1_q8, w1_2, b1_2, param_conv_bias, {},
OperatorNodeConfig{"conv_1_2", cn, dtype::QuantizedS8{6.25f}});
auto conv_1_2_fp32 = opr::TypeCvt::make(conv_1_2, dtype::Float32());
//! channel wise
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP; param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}), auto w2 = mkcvar("w2", {8, 1, 1, 3, 3}), b2 = mkcvar("b2", {1, 8, 1, 1}),
conv2 = opr::ConvBias::make(conv1, w2, b2, param_conv_bias); conv2 = opr::ConvBias::make(conv_1_2_fp32, w2, b2, param_conv_bias);
//! group //! group
auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}), auto w3 = mkcvar("w3", {2, 4, 4, 3, 3}), b3 = mkcvar("b3", {1, 8, 1, 1}),
conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias); conv3 = opr::ConvBias::make(conv2, w3, b3, param_conv_bias);
...@@ -3195,35 +3291,68 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) { ...@@ -3195,35 +3291,68 @@ TEST(TestGoptInference, ConvertFormatNCHW44_DOT) {
//! Dense //! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE; param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1; param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
auto w4 = mkcvar("w4", {4, 8, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}), auto w3_2 = mkcvar("w3_2", {16, 8, 3, 3}),
conv4 = opr::ConvBias::make(elem, w4, b4, param_conv_bias); b3_2 = mkcvar("b3_2", {1, 16, 1, 1}),
conv3_2 = opr::ConvBias::make(elem, w3_2, b3_2, param_conv_bias, {},
OperatorNodeConfig("conv3_2"));
//! s8 group conv
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto conv3_2_q8 = opr::TypeCvt::make(conv3_2, dtype::QuantizedS8(2.5f));
auto w3_3 = mkcvar_dtype("w3_3", {4, 8, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b3_3 = mkcvar_dtype("b3_3", {1, 32, 1, 1}, dtype::QuantizedS32(6.25f)),
conv3_3_q = opr::ConvBias::make(
conv3_2_q8, w3_3, b3_3, param_conv_bias, {},
OperatorNodeConfig{"conv_3_3_q", cn,
dtype::QuantizedS8{6.25f}});
auto conv3_3 = opr::TypeCvt::make(conv3_3_q, dtype::Float32());
//! Dense
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w4 = mkcvar("w4", {4, 32, 3, 3}), b4 = mkcvar("b4", {1, 4, 1, 1}),
conv4 = opr::ConvBias::make(conv3_3, w4, b4, param_conv_bias, {},
OperatorNodeConfig("conv4"));
auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}), auto w5 = mkcvar("w5", {6, 4, 3, 3}), b5 = mkcvar("b5", {1, 6, 1, 1}),
conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias); conv5 = opr::ConvBias::make(conv4, w5, b5, param_conv_bias, {},
OperatorNodeConfig("conv5"));
auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}), auto w6 = mkcvar("w6", {4, 6, 3, 3}), b6 = mkcvar("b6", {1, 4, 1, 1}),
y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias); y = opr::ConvBias::make(conv5, w6, b6, param_conv_bias, {},
OperatorNodeConfig("conv6"));
SymbolVar y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt; SymbolVar y_opt;
auto options = gopt::OptimizeForInferenceOptions{}; auto options = gopt::OptimizeForInferenceOptions{};
options.enable_fuse_conv_bias_nonlinearity();
options.enable_nchw44_dot(); options.enable_nchw44_dot();
unpack_vector(gopt::optimize_for_inference({y, conv1, conv1_f4, conv1_s4}, unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
options),
y_opt, conv1_opt, conv1_f4_opt, conv1_s4_opt); #if MEGDNN_AARCH64 || MEGDNN_ARMV7
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(conv1_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::ConvBias>(conv1_s4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW,
find_opr<opr::Convolution>(conv1_f4_opt).param().format);
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt).param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44, ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt).param().format); find_opr<opr::Convolution>(y_opt, "conv1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
#else
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::Convolution>(y_opt, "conv1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::Convolution>(y_opt, "conv1_3_q").param().format);
#endif
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt, "conv1_f1").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
find_opr<opr::ConvBias>(y_opt, "conv_1_2").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt, "conv3_2").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44_DOT,
find_opr<opr::ConvBias>(y_opt, "conv_3_3_q").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW44,
find_opr<opr::ConvBias>(y_opt, "conv4").param().format);
ASSERT_EQ(opr::Convolution::Param::Format::NCHW,
find_opr<opr::ConvBias>(y_opt, "conv5").param().format);
graph->compile({{y_opt, {}}}) graph->compile({{y_opt, {}}})
->to_json() ->to_json()
->writeto_fpath( ->writeto_fpath(output_file(
output_file("TestGoptInference.ConvertFormatNCHW44.json")); "TestGoptInference.ConvertFormatNCHW44_DOT.json"));
HostTensorND host_y_opt, host_y; HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y), auto func = graph->compile({make_callback_copy(y, host_y),
......
...@@ -608,11 +608,16 @@ public: ...@@ -608,11 +608,16 @@ public:
auto algo = get_algo(ctx); auto algo = get_algo(ctx);
size_t workspace = ctx.get_workspace_size_bytes(algo); size_t workspace = ctx.get_workspace_size_bytes(algo);
mgb_log_debug( mgb_log_debug(
"%s: input shapes (%s, %s): algo=%s " "%s: input shapes (%s %s, %s %s) -> (%s %s): algo=%s "
"workspace=%.2fMiB reproducible=%d", "workspace=%.2fMiB reproducible=%d",
mgb_opr->dyn_typeinfo()->name, mgb_opr->dyn_typeinfo()->name,
layouts[0].TensorShape::to_string().c_str(), layouts[0].TensorShape::to_string().c_str(),
layouts[1].TensorShape::to_string().c_str(), algo->name(), layouts[0].dtype.name(),
layouts[1].TensorShape::to_string().c_str(),
layouts[1].dtype.name(),
layouts[layouts.size() - 1].TensorShape::to_string().c_str(),
layouts[layouts.size() - 1].dtype.name(),
algo->name(),
workspace / (1024 * 1024.0), algo->is_reproducible()); workspace / (1024 * 1024.0), algo->is_reproducible());
megdnn_opr->execution_policy() = {algo}; megdnn_opr->execution_policy() = {algo};
return workspace; return workspace;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册