Commit f60ab501 authored by Megvii Engine Team, committed by Xu Xinran

fix(mgb/opt): nchw to nchw4 pass supports ic less than 4

GitOrigin-RevId: a3c205f38f76c8009ea7e4a8a85d12c7dd7f93e9
Parent 8ec09922
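Background for the diff below: the NCHW-to-NCHW4 pass previously required the input-channel count (IC) to be a multiple of 4, so a dense convolution fed directly by a 3-channel image could not be converted. This commit adds IC_SMALL relayout modes that round the channel dimension up to one packed group of 4. A minimal sketch of the shape rule the pass uses; the helper name is hypothetical and the zero-fill of the padded channels is an assumption about the relayout, not something stated in this diff:

#include <array>
#include <cstddef>

// Hypothetical helper for illustration: the NCHW4 shape inferred for an NCHW
// tensor, rounding the channel count up to a whole group of 4. This mirrors
// the "(inp_shape[1] + 4 - 1) / 4" rule added in the pass below.
std::array<std::size_t, 5> nchw_to_nchw4_shape(
        const std::array<std::size_t, 4>& nchw) {
    return {nchw[0],                // N
            (nchw[1] + 4 - 1) / 4,  // ceil(C / 4) packed channel groups
            nchw[2],                // H
            nchw[3],                // W
            4};                     // 4 channels per group
}
// e.g. {2, 3, 16, 16} -> {2, 1, 16, 16, 4}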
@@ -10,6 +10,7 @@
  */
 #include "src/cuda/convolution/opr_impl.h"
+#include "megdnn/dtype.h"
 #include "src/cuda/convolution/helper.h"
 #include "src/cuda/convolution/backward_data/algo.h"
 #include "src/cuda/convolution/backward_filter/algo.h"
@@ -28,10 +29,35 @@ using namespace convolution;
 /* ============== ConvolutionForwardImpl ============== */
 ConvolutionForwardImpl::ConvBiasExtraData
-ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& dst) {
+ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src,
+                                             const TensorLayout& filter,
+                                             const TensorLayout& dst) {
     auto conv_param = param();
+    DType bias_type;
+    if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
+        bias_type = dtype::QuantizedS32(
+                src.dtype.param<dtype::QuantizedS8>().scale *
+                filter.dtype.param<dtype::QuantizedS8>().scale);
+    } else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
+        bias_type = dtype::QuantizedS32(
+                src.dtype.param<dtype::Quantized8Asymm>().scale *
+                filter.dtype.param<dtype::Quantized8Asymm>().scale);
+    } else if (src.dtype.enumv() == DTypeEnum::Uint8 ||
+               src.dtype.enumv() == DTypeEnum::Int8) {
+        bias_type = dtype::Int32{};
+    } else if (src.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
+        bias_type = dtype::QuantizedS32(
+                src.dtype.param<dtype::Quantized4Asymm>().scale *
+                filter.dtype.param<dtype::Quantized4Asymm>().scale);
+    } else {
+        megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
+        bias_type = src.dtype;
+    }
     ConvBiasExtraData ret = {this->handle()->create_operator<ConvBiasForward>(),
-                             TensorLayout(dst.dtype), TensorLayout(dst.dtype)};
+                             TensorLayout(bias_type), TensorLayout(dst.dtype)};
     ret.convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY,
                                  conv_param.mode,
                                  conv_param.sparse,
@@ -54,7 +80,7 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
                                                 const TensorLayout& dst,
                                                 size_t workspace_limit_in_bytes,
                                                 bool reproducible) {
-    auto extra_data = conv_bias_extra_data(dst);
+    auto extra_data = conv_bias_extra_data(src, filter, dst);
     return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
             ->get_algorithm_heuristic(src, filter, extra_data.bias_layout,
                                       extra_data.z_layout, dst,
@@ -65,7 +91,7 @@ std::vector<ConvolutionForwardImpl::Algorithm*>
 ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
                                            const TensorLayout& filter,
                                            const TensorLayout& dst) {
-    auto extra_data = conv_bias_extra_data(dst);
+    auto extra_data = conv_bias_extra_data(src, filter, dst);
     return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
             ->get_all_algorithms(src, filter, extra_data.bias_layout,
                                  extra_data.z_layout, dst);
@@ -75,7 +101,7 @@ size_t ConvolutionForwardImpl::get_workspace_in_bytes(
         const TensorLayout& src, const TensorLayout& filter,
         const TensorLayout& dst,
         const PreprocessedFilter* preprocessed_filter) {
-    auto extra_data = conv_bias_extra_data(dst);
+    auto extra_data = conv_bias_extra_data(src, filter, dst);
     return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
             ->get_workspace_in_bytes(
                     src, filter, extra_data.bias_layout, extra_data.z_layout,
@@ -90,7 +116,8 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
                                   _megdnn_tensor_out dst,
                                   const PreprocessedFilter* preprocessed_filter,
                                   _megdnn_workspace workspace) {
-    auto extra_data = conv_bias_extra_data(dst.layout);
+    auto extra_data =
+            conv_bias_extra_data(src.layout, filter.layout, dst.layout);
     TensorND bias(nullptr, extra_data.bias_layout);
     TensorND z(nullptr, extra_data.z_layout);
     return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
...
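Why conv_bias_extra_data now takes src and filter: the internal ConvBiasForward operator is handed an implicit bias layout, and for quantized inputs its dtype can no longer just copy dst — it must be QuantizedS32 whose scale is the product of the input and filter scales (plain Int8/Uint8 fall back to Int32, float inputs reuse the source dtype). A small self-contained sketch of that scale rule; the function is illustrative, not MegDNN API, and the numbers simply match the new test added at the end of this commit:

#include <cassert>
#include <cstdio>

// Illustrative only: scale of the implicit QuantizedS32 bias for a quantized
// int8 convolution, as derived in the branch added above.
float quantized_bias_scale(float src_scale, float filter_scale) {
    return src_scale * filter_scale;
}

int main() {
    // The new test uses QuantizedS8(2.5f) activations and weights with a
    // QuantizedS32(6.25f) bias: 2.5 * 2.5 == 6.25 exactly.
    assert(quantized_bias_scale(2.5f, 2.5f) == 6.25f);
    std::printf("bias scale = %g\n", quantized_bias_scale(2.5f, 2.5f));
    return 0;
}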
@@ -61,7 +61,9 @@ class ConvolutionForwardImpl: public ConvolutionForward {
             TensorLayout z_layout;
         };
     private:
-        ConvBiasExtraData conv_bias_extra_data(const TensorLayout&);
+        ConvBiasExtraData conv_bias_extra_data(const TensorLayout&,
+                                               const TensorLayout&,
+                                               const TensorLayout&);
 };
 
 class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
...
@@ -60,19 +60,24 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
 public:
     //! relayout type of this opr
     enum class LayoutType {
         NCHW4_TO_NCHW32,  //!< from nchw4 layout to nchw32 layout
         NCHW32_TO_NCHW4,  //!< from nchw32 layout to nchw4 layout
         NCHW4_TO_CHWN4,   //!< from nchw4 layout to chwn4 layout
         CHWN4_TO_NCHW4,   //!< from chwn4 layout to nchw4 layout
         NCHW_TO_NCHW4,    //!< from nchw layout to nchw4 layout
+        NCHW_TO_NCHW4_IC_SMALL_CONV,  ///< from nchw layout to nchw4 whose
+                                      ///< channel size less than 4
         NCHW4_TO_NCHW,    //!< from nchw4 layout to nchw layout
         NCHW_TO_NCHW88,   //!< from nchw layout to nchw88 layout
         NCHW88_TO_NCHW,   //!< from nchw88 layout to nchw layout
         WEIGHT_NCHW_TO_NCHW4_DENSE,  //!< weight from nchw layout to nchw4
                                      //!< layout
         WEIGHT_NCHW_TO_NCHW4_GROUP,  //!< group weight from nchw layout to
                                      //!< nchw4 layout
+        WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,  //!< weight from nchw layout
+                                                   //!< to nchw4 layout whose
+                                                   //!< channel size less than 4
         WEIGHT_NCHW_TO_NCHW88_DENSE,  //!< weight from nchw layout to nchw88
                                       //!< layout
@@ -177,11 +182,21 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[3] = inp_shape[2];
             dst[4] = inp_shape[4];
         } else if (layout_type() ==
-                   RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4){
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+                           RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 ||
+                   layout_type() == RelayoutPlaceholder::LayoutType::
+                                            NCHW_TO_NCHW4_IC_SMALL_CONV) {
+            if (layout_type() ==
+                RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            } else {
+                mgb_assert(layout_type() ==
+                           RelayoutPlaceholder::LayoutType::
+                                   NCHW_TO_NCHW4_IC_SMALL_CONV);
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
+            }
             dst.ndim = 5;
             dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] / 4;
+            dst[1] = (inp_shape[1] + 4 - 1) / 4;
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 4;
@@ -194,11 +209,23 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
         } else if (layout_type() == RelayoutPlaceholder::LayoutType::
-                                            WEIGHT_NCHW_TO_NCHW4_DENSE) {
-            mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+                                            WEIGHT_NCHW_TO_NCHW4_DENSE ||
+                   layout_type() ==
+                           RelayoutPlaceholder::LayoutType::
+                                   WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) {
+            if (layout_type() ==
+                RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) {
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
+            } else {
+                mgb_assert(layout_type() ==
+                           RelayoutPlaceholder::LayoutType::
+                                   WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV);
+                mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
+            }
             dst.ndim = 5;
             dst[0] = inp_shape[0];
-            dst[1] = inp_shape[1] / 4;
+            dst[1] = (inp_shape[1] + 4 - 1) / 4;
             dst[2] = inp_shape[2];
             dst[3] = inp_shape[3];
             dst[4] = 4;
@@ -427,6 +454,23 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
         auto y2 = opr::Reshape::make(y1, tshp1);
         return y2.node();
     };
+    reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::RelayoutFormat::make(
+                x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
+        return y.node();
+    };
+    reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] =
+            [](VarNode* inp) -> VarNode* {
+        auto x = SymbolVar(inp);
+        auto y = opr::RelayoutFormat::make(
+                x, megdnn::param::RelayoutFormat::Mode::
+                           NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
+        return y.node();
+    };
     reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
         auto x = SymbolVar(inp);
         auto xshp = opr::GetVarShape::make(x);
@@ -1367,29 +1411,40 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     RelayoutMode weight_to_nchw4_mode_group =
             RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
-    auto trans_nchw4 = [weight_to_nchw4_mode_dense,
-                        weight_to_nchw4_mode_group](
+    struct ConvMode {
+        RelayoutMode weight;
+        RelayoutMode src;
+    };
+    auto trans_nchw4 =
+            [weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group,
+             src_to_nchw4_mode](
                     const megdnn::param::Convolution::Sparse conv_mode,
-                    const VarNode* filter) -> RelayoutMode {
+                    const VarNode* filter) -> ConvMode {
         if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
             mgb_assert(filter->shape().ndim == 4,
                        "The origin filter is not NCHW mode");
             size_t IC = filter->shape()[1];
-            mgb_assert(IC % 4 == 0,
-                       "The input channel should be divisible by 4");
-            return weight_to_nchw4_mode_dense;
+            if (IC < 4) {
+                return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
+                        RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV};
+            } else {
+                return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
+            }
         } else {
             mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
             mgb_assert(filter->shape().ndim == 5,
                        "The origin filter if not NCHW mode");
             size_t IC = filter->shape()[2];
             mgb_assert(IC % 4 == 0,
-                       "The input channel should be divisible by 4");
-            return weight_to_nchw4_mode_group;
+                       "The input channel should be divisible by 4 for group "
+                       "conv");
+            return {weight_to_nchw4_mode_group, src_to_nchw4_mode};
         }
     };
-    auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
-            OperatorNodeBase* opr, const VarNodeArray& new_inp) {
+    auto replace_conv_opr = [trans_nchw4, conv_format](
+                                    OperatorNodeBase* opr,
+                                    const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
         if (conv_opr.param().format !=
@@ -1397,18 +1452,19 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
             return serialization::copy_opr_shallow(*opr, new_inp,
                                                    opr->config());
         }
+        auto conv_mode =
+                trans_nchw4(conv_opr.param().sparse, new_inp[1]);
         VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
         // src: NCHW --> NCWH4
         if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     src_to_nchw4_mode);
+            auto new_src =
+                    RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_src = new_src.node();
         }
         // weight: NCHW --> NCHW4
-        auto weight_mode =
-                trans_nchw4(conv_opr.param().sparse, new_inp[1]);
-        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        auto new_filter =
+                RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
         conv_filter = new_filter.node();
         // format: NCHW --> NCHW4
         auto new_param = conv_opr.param();
@@ -1499,8 +1555,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
     };
     auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
                                   src_to_nchw4_mode](
                                          OperatorNodeBase* opr,
                                          const VarNodeArray& new_inp) {
         mgb_assert(opr->input().size() == new_inp.size());
         auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
         if (conv_bias_opr.param().format !=
@@ -1511,17 +1567,18 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // what should be converted: src, weight
         VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
+        auto conv_mode =
+                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
         // src: NCHW --> NCHW4
-        if (new_inp[0]->shape().ndim !=5) {
+        if (new_inp[0]->shape().ndim != 5) {
             mgb_assert(new_inp[0]->shape().ndim == 4);
-            auto new_src = RelayoutPlaceholder::make(new_inp[0],
-                                                     src_to_nchw4_mode);
+            auto new_src =
+                    RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
             conv_bias_src = new_src.node();
         }
         // weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
-        auto weight_mode =
-                trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
-        auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
+        auto new_filter =
+                RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
         conv_bias_filter = new_filter.node();
         // format: NCHW --> NCHW4
         auto new_param = conv_bias_opr.param();
@@ -1538,8 +1595,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // bias: NCHW --> NCHW4
         VarNode* conv_bias_bias = new_inp[2];
         if (new_inp[2]->shape().ndim == 4) {
-            auto new_bias = RelayoutPlaceholder::make(new_inp[2],
-                                                      src_to_nchw4_mode);
+            auto new_bias =
+                    RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
             conv_bias_bias = new_bias.node();
         }
         if (new_inp.size() == 3) {
@@ -1554,8 +1611,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
         // z_inp: NCHW --> NCHW4
         VarNode* z_inp = new_inp[3];
         if (new_inp[3]->shape().ndim == 4) {
-            auto new_z = RelayoutPlaceholder::make(new_inp[3],
-                                                   src_to_nchw4_mode);
+            auto new_z =
+                    RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
             z_inp = new_z.node();
         }
         auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
...
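In the optimization pass above, trans_nchw4 now returns a ConvMode pair (weight mode plus source mode) instead of only the weight mode, so a dense convolution with IC < 4 picks the IC_SMALL relayouts for both its input and its filter while every other case keeps the original NCHW_TO_NCHW4 path. A condensed, standalone sketch of that decision; enum values are reduced to strings purely for illustration, and group convolutions are unchanged (they still assert IC % 4 == 0 in the pass):

#include <cstddef>
#include <string>
#include <utility>

// Illustrative reduction of the new trans_nchw4 decision for a dense
// convolution: fewer than 4 input channels -> IC_SMALL relayout modes,
// otherwise the regular NCHW -> NCHW4 modes.
std::pair<std::string, std::string> pick_dense_relayout(std::size_t ic) {
    if (ic < 4) {
        return {"WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV",
                "NCHW_TO_NCHW4_IC_SMALL_CONV"};
    }
    return {"WEIGHT_NCHW_TO_NCHW4_DENSE", "NCHW_TO_NCHW4"};
}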
@@ -2725,6 +2725,67 @@ TEST(TestGoptInference, ConvertFormatNCHW4) {
     MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
 }
 
+TEST(TestGoptInference, ConvertFormatNCHW4Ic3) {
+    REQUIRE_GPU(1);
+    HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{
+            1.2f, 127 * 127};
+    auto graph = ComputingGraph::make();
+    graph->options().graph_opt_level = 0;
+    auto mkvar = [&](const char* name, const TensorShape& shp,
+                     const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name),
+                dtype);
+    };
+    auto mkcvar = [&](const char* name, const TensorShape& shp,
+                      const DType& dtype) {
+        return opr::TypeCvt::make(
+                opr::SharedDeviceTensor::make(*graph, *gen(shp))
+                        .rename(name),
+                dtype);
+    };
+
+    auto x = mkvar("x", {2, 3, 16, 16}, dtype::QuantizedS8(2.5f));
+    // ConvBias test dense
+    opr::ConvBias::Param param_conv_bias;
+    param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
+    auto w1 = mkcvar("w1", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
+    auto conv1 =
+            opr::ConvBias::make(x, w1, b1, param_conv_bias, {},
+                                OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
+    param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
+    auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
+         b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
+    auto conv2 =
+            opr::ConvBias::make(conv1, w2, b2, param_conv_bias, {},
+                                OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
+
+    auto y = opr::TypeCvt::make(conv2, dtype::Float32());
+
+    SymbolVar y_opt;
+    {
+        auto options = gopt::OptimizeForInferenceOptions{};
+        options.enable_nchw4();
+        unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
+    }
+    ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
+              find_opr<opr::ConvBias>(y_opt).param().format);
+
+    graph->compile({{y_opt, {}}})
+            ->to_json()
+            ->writeto_fpath(output_file(
+                    "TestGoptInference.ConvertFormatNCHW4Ic3.json"));
+
+    HostTensorND host_y_opt, host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y),
+                                make_callback_copy(y_opt, host_y_opt)});
+    func->execute();
+    MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
+}
+
 TEST(TestGoptInference, ConvertFormatNCHW88) {
     HostTensorGenerator<> gen;
     auto cn = CompNode::load("cpu0");
...