提交 486cbdea 编写于 作者: M Megvii Engine Team

fix(mgb/opt): nchw to nchw4 pass suppport ic less than 4

GitOrigin-RevId: a3c205f38f76c8009ea7e4a8a85d12c7dd7f93e9
上级 1c3d1f86
......@@ -10,6 +10,7 @@
*/
#include "src/cuda/convolution/opr_impl.h"
#include "megdnn/dtype.h"
#include "src/cuda/convolution/helper.h"
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution/backward_filter/algo.h"
......@@ -28,10 +29,35 @@ using namespace convolution;
/* ============== ConvolutionForwardImpl ============== */
ConvolutionForwardImpl::ConvBiasExtraData
ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& dst) {
ConvolutionForwardImpl::conv_bias_extra_data(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
auto conv_param = param();
DType bias_type;
if (src.dtype.enumv() == DTypeEnum::QuantizedS8) {
bias_type = dtype::QuantizedS32(
src.dtype.param<dtype::QuantizedS8>().scale *
filter.dtype.param<dtype::QuantizedS8>().scale);
} else if (src.dtype.enumv() == DTypeEnum::Quantized8Asymm) {
bias_type = dtype::QuantizedS32(
src.dtype.param<dtype::Quantized8Asymm>().scale *
filter.dtype.param<dtype::Quantized8Asymm>().scale);
} else if (src.dtype.enumv() == DTypeEnum::Uint8 ||
src.dtype.enumv() == DTypeEnum::Int8) {
bias_type = dtype::Int32{};
} else if (src.dtype.enumv() == DTypeEnum::Quantized4Asymm) {
bias_type = dtype::QuantizedS32(
src.dtype.param<dtype::Quantized4Asymm>().scale *
filter.dtype.param<dtype::Quantized4Asymm>().scale);
} else {
megdnn_assert(src.dtype.category() == DTypeCategory::FLOAT);
bias_type = src.dtype;
}
ConvBiasExtraData ret = {this->handle()->create_operator<ConvBiasForward>(),
TensorLayout(dst.dtype), TensorLayout(dst.dtype)};
TensorLayout(bias_type), TensorLayout(dst.dtype)};
ret.convbias_opr->param() = {param::ConvBias::NonlineMode::IDENTITY,
conv_param.mode,
conv_param.sparse,
......@@ -54,7 +80,7 @@ ConvolutionForwardImpl::get_algorithm_heuristic(const TensorLayout& src,
const TensorLayout& dst,
size_t workspace_limit_in_bytes,
bool reproducible) {
auto extra_data = conv_bias_extra_data(dst);
auto extra_data = conv_bias_extra_data(src, filter, dst);
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
->get_algorithm_heuristic(src, filter, extra_data.bias_layout,
extra_data.z_layout, dst,
......@@ -65,7 +91,7 @@ std::vector<ConvolutionForwardImpl::Algorithm*>
ConvolutionForwardImpl::get_all_algorithms(const TensorLayout& src,
const TensorLayout& filter,
const TensorLayout& dst) {
auto extra_data = conv_bias_extra_data(dst);
auto extra_data = conv_bias_extra_data(src, filter, dst);
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
->get_all_algorithms(src, filter, extra_data.bias_layout,
extra_data.z_layout, dst);
......@@ -75,7 +101,7 @@ size_t ConvolutionForwardImpl::get_workspace_in_bytes(
const TensorLayout& src, const TensorLayout& filter,
const TensorLayout& dst,
const PreprocessedFilter* preprocessed_filter) {
auto extra_data = conv_bias_extra_data(dst);
auto extra_data = conv_bias_extra_data(src, filter, dst);
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
->get_workspace_in_bytes(
src, filter, extra_data.bias_layout, extra_data.z_layout,
......@@ -90,7 +116,8 @@ void ConvolutionForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_out dst,
const PreprocessedFilter* preprocessed_filter,
_megdnn_workspace workspace) {
auto extra_data = conv_bias_extra_data(dst.layout);
auto extra_data =
conv_bias_extra_data(src.layout, filter.layout, dst.layout);
TensorND bias(nullptr, extra_data.bias_layout);
TensorND z(nullptr, extra_data.z_layout);
return static_cast<ConvBiasForwardImpl*>(extra_data.convbias_opr.get())
......
......@@ -61,7 +61,9 @@ class ConvolutionForwardImpl: public ConvolutionForward {
TensorLayout z_layout;
};
private:
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&);
ConvBiasExtraData conv_bias_extra_data(const TensorLayout&,
const TensorLayout&,
const TensorLayout&);
};
class ConvolutionBackwardDataImpl: public ConvolutionBackwardData {
......
......@@ -60,19 +60,24 @@ MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
public:
//! relayout type of this opr
enum class LayoutType {
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose
///< channel size less than 4
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout
//!< to nchw4 layout whose
//! channel size less than 4
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
......@@ -177,11 +182,21 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[3] = inp_shape[2];
dst[4] = inp_shape[4];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4){
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 ||
layout_type() == RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW4_IC_SMALL_CONV) {
if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
} else {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW4_IC_SMALL_CONV);
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
}
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[1] = (inp_shape[1] + 4 - 1) / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
......@@ -194,11 +209,23 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
WEIGHT_NCHW_TO_NCHW4_DENSE ||
layout_type() ==
RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) {
if (layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
} else {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV);
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
}
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[1] = (inp_shape[1] + 4 - 1) / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
......@@ -427,6 +454,23 @@ void TensorReformatPass::translate_pass(OptState& opt) const {
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto y = opr::RelayoutFormat::make(
x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
return y.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto y = opr::RelayoutFormat::make(
x, megdnn::param::RelayoutFormat::Mode::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
return y.node();
};
reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
......@@ -1367,29 +1411,40 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
RelayoutMode weight_to_nchw4_mode_group =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
auto trans_nchw4 = [weight_to_nchw4_mode_dense,
weight_to_nchw4_mode_group](
struct ConvMode {
RelayoutMode weight;
RelayoutMode src;
};
auto trans_nchw4 =
[weight_to_nchw4_mode_dense, weight_to_nchw4_mode_group,
src_to_nchw4_mode](
const megdnn::param::Convolution::Sparse conv_mode,
const VarNode* filter) -> RelayoutMode {
const VarNode* filter) -> ConvMode {
if (conv_mode == megdnn::param::Convolution::Sparse::DENSE) {
mgb_assert(filter->shape().ndim == 4,
"The origin filter is not NCHW mode");
size_t IC = filter->shape()[1];
mgb_assert(IC % 4 == 0,
"The input channel should be divisible by 4");
return weight_to_nchw4_mode_dense;
if (IC < 4) {
return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV};
} else {
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
}
} else {
mgb_assert(conv_mode == megdnn::param::Convolution::Sparse::GROUP);
mgb_assert(filter->shape().ndim == 5,
"The origin filter if not NCHW mode");
size_t IC = filter->shape()[2];
mgb_assert(IC % 4 == 0,
"The input channel should be divisible by 4");
return weight_to_nchw4_mode_group;
"The input channel should be divisible by 4 for group "
"conv");
return {weight_to_nchw4_mode_group, src_to_nchw4_mode};
}
};
auto replace_conv_opr = [trans_nchw4, conv_format, src_to_nchw4_mode](
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
auto replace_conv_opr = [trans_nchw4, conv_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_opr = opr->cast_final_safe<opr::ConvolutionForward>();
if (conv_opr.param().format !=
......@@ -1397,18 +1452,19 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
auto conv_mode =
trans_nchw4(conv_opr.param().sparse, new_inp[1]);
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
// src: NCHW --> NCWH4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(new_inp[0],
src_to_nchw4_mode);
auto new_src =
RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
conv_src = new_src.node();
}
// weight: NCHW --> NCHW4
auto weight_mode =
trans_nchw4(conv_opr.param().sparse, new_inp[1]);
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
conv_filter = new_filter.node();
// format: NCHW --> NCHW4
auto new_param = conv_opr.param();
......@@ -1499,8 +1555,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
};
auto replace_conv_bias_opr = [trans_nchw4, conv_bias_format,
src_to_nchw4_mode](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& conv_bias_opr = opr->cast_final_safe<opr::ConvBiasForward>();
if (conv_bias_opr.param().format !=
......@@ -1511,17 +1567,18 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// what should be converted: src, weight
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
auto conv_mode =
trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim !=5) {
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(new_inp[0],
src_to_nchw4_mode);
auto new_src =
RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
conv_bias_src = new_src.node();
}
// weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
auto weight_mode =
trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
auto new_filter =
RelayoutPlaceholder::make(new_inp[1], conv_mode.weight);
conv_bias_filter = new_filter.node();
// format: NCHW --> NCHW4
auto new_param = conv_bias_opr.param();
......@@ -1538,8 +1595,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// bias: NCHW --> NCHW4
VarNode* conv_bias_bias = new_inp[2];
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(new_inp[2],
src_to_nchw4_mode);
auto new_bias =
RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
conv_bias_bias = new_bias.node();
}
if (new_inp.size() == 3) {
......@@ -1554,8 +1611,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// z_inp: NCHW --> NCHW4
VarNode* z_inp = new_inp[3];
if (new_inp[3]->shape().ndim == 4) {
auto new_z = RelayoutPlaceholder::make(new_inp[3],
src_to_nchw4_mode);
auto new_z =
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
z_inp = new_z.node();
}
auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
......
......@@ -2725,6 +2725,67 @@ TEST(TestGoptInference, ConvertFormatNCHW4) {
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNCHW4Ic3) {
REQUIRE_GPU(1);
HostTensorGenerator<dtype::Float32, RandomDistribution::UNIFORM> gen{
1.2f, 127 * 127};
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
auto mkvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name),
dtype);
};
auto mkcvar = [&](const char* name, const TensorShape& shp,
const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp))
.rename(name),
dtype);
};
auto x = mkvar("x", {2, 3, 16, 16}, dtype::QuantizedS8(2.5f));
// ConvBias test dense
opr::ConvBias::Param param_conv_bias;
param_conv_bias.pad_h = param_conv_bias.pad_w = 1;
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::DENSE;
auto w1 = mkcvar("w1", {8, 3, 3, 3}, dtype::QuantizedS8(2.5f)),
b1 = mkcvar("b1", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv1 =
opr::ConvBias::make(x, w1, b1, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
param_conv_bias.sparse = opr::ConvBias::Param::Sparse::GROUP;
auto w2 = mkcvar("w2", {2, 4, 4, 3, 3}, dtype::QuantizedS8(2.5f)),
b2 = mkcvar("b2", {1, 8, 1, 1}, dtype::QuantizedS32(6.25f));
auto conv2 =
opr::ConvBias::make(conv1, w2, b2, param_conv_bias, {},
OperatorNodeConfig{dtype::QuantizedS8{2.5f}});
auto y = opr::TypeCvt::make(conv2, dtype::Float32());
SymbolVar y_opt;
{
auto options = gopt::OptimizeForInferenceOptions{};
options.enable_nchw4();
unpack_vector(gopt::optimize_for_inference({y}, options), y_opt);
}
ASSERT_EQ(opr::ConvBias::Param::Format::NCHW4,
find_opr<opr::ConvBias>(y_opt).param().format);
graph->compile({{y_opt, {}}})
->to_json()
->writeto_fpath(output_file(
"TestGoptInference.ConvertFormatNCHW4Ic3.json"));
HostTensorND host_y_opt, host_y;
auto func = graph->compile({make_callback_copy(y, host_y),
make_callback_copy(y_opt, host_y_opt)});
func->execute();
MGB_ASSERT_TENSOR_NEAR(host_y, host_y_opt, 1e-3);
}
TEST(TestGoptInference, ConvertFormatNCHW88) {
HostTensorGenerator<> gen;
auto cn = CompNode::load("cpu0");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册