Commit 777f3ea9 authored by Megvii Engine Team, committed by Xinran Xu

refactor(gopt): format code

GitOrigin-RevId: 9d5c87000fdfa291d91306365f7401c8af443dc1
Parent b44e0549
@@ -10,23 +10,23 @@
* implied.
*/
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/basic_arith.h"
#include "megbrain/gopt/gtrans.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/graph/event.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/utils/shared_set.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/blas.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megbrain/utils/shared_set.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h"
@@ -68,8 +68,8 @@ using namespace gopt;
MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
cg::SingleCNOperatorNodeBase) // {
public:
//! relayout type of this opr
enum class LayoutType {
//! relayout type of this opr
enum class LayoutType {
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
@@ -112,25 +112,28 @@ public:
//!< NCHW44_DOT layout dense
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout group
};
};
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
/*!
/*!
* \param src_var the input var
* \param layout_type tensor layout transform type of this relayout
* placeholder as described in LayoutType
*/
static SymbolVar make(VarNode* src_var, LayoutType layout_type);
static SymbolVar make(VarNode* src_var, LayoutType layout_type);
LayoutType layout_type() const { return m_layout_type; }
LayoutType layout_type() const {
return m_layout_type;
}
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const LayoutType m_layout_type;
};
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const LayoutType m_layout_type;
}
;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);
@@ -211,7 +214,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW){
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
@@ -249,7 +252,7 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
}else if (layout_type() ==
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
dst.ndim = 5;
@@ -1033,7 +1036,6 @@ EnableTensorCorePass::make_tensorcore_converter() {
"can not be changed in this opt "
"pass");
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
auto replace_warp_affine_opr =
[replace_inps_to_nchw4, replace_non_nchw4_opr](
@@ -1247,7 +1249,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
}
if (nr_shape_changed) {
auto inps = new_inp;
if (nr_shape_changed >= nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4
if (nr_shape_changed >=
nr_inps / 2) { // CHWN4 > NCHW4 -> use CHWN4
for (size_t i = 0; i < nr_inps; ++i) {
if (varshape_changed.count(new_inp[i]) == 0) {
auto symvar = RelayoutPlaceholder::make(
@@ -1309,7 +1312,6 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
"can not be changed in this opt "
"pass");
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
// capture by copy to avoid use after return
auto replace_warp_affine_opr =
@@ -1410,7 +1412,7 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
return new_var;
}
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
MIDOUT_B("EnableNCHW4Pass::make")
auto ret = std::make_unique<EnableNCHW4Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
@@ -1469,14 +1471,12 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
}
auto conv_mode =
trans_nchw4(conv_opr.param().sparse, new_inp[1]);
auto conv_mode = trans_nchw4(conv_opr.param().sparse, new_inp[1]);
VarNode *conv_src = new_inp[0], *conv_filter = new_inp[1];
// src: NCHW --> NCWH4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src =
RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
conv_src = new_src.node();
}
// weight: NCHW --> NCHW4
@@ -1488,8 +1488,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
new_param.format = conv_format;
// dst
auto new_conv_opr = opr::Convolution::make(
conv_src, conv_filter, new_param,
conv_opr.execution_policy(), conv_opr.config());
conv_src, conv_filter, new_param, conv_opr.execution_policy(),
conv_opr.config());
OperatorNodeBase* new_opr = new_conv_opr.node()->owner_opr();
mgb_assert(new_conv_opr.shape().ndim == 5,
"The conv dst dim is not trans to nchw4");
@@ -1515,10 +1515,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// what should be converted: src, weight
VarNode *src = new_inp[0], *filter = new_inp[1];
// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim !=5) {
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(new_inp[0],
src_to_nchw4_mode);
auto new_src =
RelayoutPlaceholder::make(new_inp[0], src_to_nchw4_mode);
src = new_src.node();
}
// weight: BNCHW --> BNCHW4
......@@ -1531,7 +1531,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
auto new_param = batch_conv_bias_opr.param();
new_param.format = batch_conv_bias_format;
if (new_inp.size() == 2) {
auto dst = opr::BatchConvBias::make(src, filter, new_param,
auto dst = opr::BatchConvBias::make(
src, filter, new_param,
batch_conv_bias_opr.execution_policy(),
batch_conv_bias_opr.config());
OperatorNodeBase* new_opr = dst.node()->owner_opr();
@@ -1542,12 +1543,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// bias: NCHW --> NCHW4
VarNode* bias = new_inp[2];
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(new_inp[2],
src_to_nchw4_mode);
auto new_bias =
RelayoutPlaceholder::make(new_inp[2], src_to_nchw4_mode);
bias = new_bias.node();
}
if (new_inp.size() == 3) {
auto dst = opr::BatchConvBias::make(src, filter, bias, new_param,
auto dst = opr::BatchConvBias::make(
src, filter, bias, new_param,
batch_conv_bias_opr.execution_policy(),
batch_conv_bias_opr.config());
OperatorNodeBase* new_opr = dst.node()->owner_opr();
@@ -1558,12 +1560,13 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// z_inp: NCHW --> NCHW4
VarNode* z_inp = new_inp[3];
if (new_inp[3]->shape().ndim == 4) {
auto new_z = RelayoutPlaceholder::make(new_inp[3],
src_to_nchw4_mode);
auto new_z =
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
z_inp = new_z.node();
}
auto dst = opr::BatchConvBias::make(src, filter, bias, z_inp,
new_param,batch_conv_bias_opr.execution_policy(),
auto dst =
opr::BatchConvBias::make(src, filter, bias, z_inp, new_param,
batch_conv_bias_opr.execution_policy(),
batch_conv_bias_opr.config());
OperatorNodeBase* new_opr = dst.node()->owner_opr();
mgb_assert(dst.shape().ndim == 5,
@@ -1584,13 +1587,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
// what should be converted: src, weight
VarNode *conv_bias_src = new_inp[0], *conv_bias_filter = new_inp[1];
auto conv_mode =
trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
auto conv_mode = trans_nchw4(conv_bias_opr.param().sparse, new_inp[1]);
// src: NCHW --> NCHW4
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src =
RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
auto new_src = RelayoutPlaceholder::make(new_inp[0], conv_mode.src);
conv_bias_src = new_src.node();
}
// weight: NCHW --> NCHW4 or GNCHW --> GNCHW4
@@ -1632,9 +1633,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
RelayoutPlaceholder::make(new_inp[3], src_to_nchw4_mode);
z_inp = new_z.node();
}
auto new_conv_bias_opr = opr::ConvBias::make(conv_bias_src,
conv_bias_filter, conv_bias_bias, z_inp, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
auto new_conv_bias_opr = opr::ConvBias::make(
conv_bias_src, conv_bias_filter, conv_bias_bias, z_inp,
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
@@ -1654,8 +1656,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
auto temp_inp = new_inp;
for (size_t i = 0; i < opr->input().size(); i++) {
if (new_inp[i]->shape().ndim == 4) {
auto new_var = RelayoutPlaceholder::make(
new_inp[i], src_to_nchw4_mode);
auto new_var = RelayoutPlaceholder::make(new_inp[i],
src_to_nchw4_mode);
temp_inp[i] = new_var.node();
} else {
mgb_assert((new_inp[i]->shape().ndim == 5) ||
@@ -1697,8 +1699,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
mgb_assert(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8);
auto new_param = pooling.param();
new_param.format = Format::NCHW4;
auto new_pooling =
opr::PoolingForward::make(new_inp[0], new_param, opr->config());
auto new_pooling = opr::PoolingForward::make(new_inp[0], new_param,
opr->config());
mgb_assert(new_pooling.shape().ndim == 5,
"out var of Pooling opr after transform must be 5 (got: "
"%zu).",
@@ -1767,8 +1769,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter(){
//! supportted nchw4
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] =
replace_batch_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
replace_func[opr::ResizeForward::typeinfo()] = replace_resize_opr;
replace_func[opr::WarpPerspectiveForward::typeinfo()] =
@@ -1811,7 +1812,7 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
return new_var;
}
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size){
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestFilterResult = std::pair<TransType, RelayoutMode>;
RelayoutMode weight_to_nchwxx_mode_dense =
@@ -2205,7 +2206,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
size_t OC = filter->shape()[0];
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
} else if (IC < pack_c_size && OC % pack_c_size == 0) {
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
@@ -2223,7 +2225,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44_DOT;
}
}
@@ -2538,7 +2541,6 @@ public:
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =
@@ -2563,7 +2565,6 @@ public:
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
return y1.node();
};
m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
@@ -2593,22 +2594,27 @@ public:
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
cg::SingleCNOperatorNodeBase) // {
public:
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
TensorFormat inp_format() const { return m_inp_format; }
TensorFormat inp_format() const {
return m_inp_format;
}
TensorFormat out_format() const { return m_out_format; }
TensorFormat out_format() const {
return m_out_format;
}
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
const TensorFormat m_inp_format;
const TensorFormat m_out_format;
};
void init_output_static_infer_desc() override;
void scn_do_execute() override;
const TensorFormat m_inp_format;
const TensorFormat m_out_format;
}
;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);
@@ -2914,7 +2920,8 @@ void ShuffleShuffleRemovePass::Impl::do_replace() {
bool force_folding_typecvt = false;
bool first_shuffle = false;
// initialize inp_format and out_format
TensorFormat out_format = TensorFormat::NCHW, inp_format = out_format;
TensorFormat out_format = TensorFormat::NCHW,
inp_format = out_format;
megdnn::DType inp_dtype = cur->input(0)->dtype(),
out_dtype = cur->output(0)->dtype();
SmallVector<megdnn::DType> out_dtype_vec;
......
This diff is collapsed.
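For readers skimming the diff: the shape arithmetic reformatted in the init_output_static_infer_desc() hunks above packs an NCHW tensor shape into NCHW4 by splitting the channel dimension into blocks of 4 (dst = {N, C/4, H, W, 4}). The following is a minimal standalone C++ sketch of that mapping, not part of the commit; the function name nchw_to_nchw4 and the sample shape are purely illustrative.

// Minimal sketch, assuming only the NCHW -> NCHW4 shape rule shown in the
// diff (dst = {N, C/4, H, W, 4}); all names here are illustrative.
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdio>

// Pack a 4-D NCHW shape into the 5-D NCHW4 layout: the channel dimension
// is split into C/4 blocks of 4 channels each.
std::array<std::size_t, 5> nchw_to_nchw4(const std::array<std::size_t, 4>& nchw) {
    assert(nchw[1] % 4 == 0 && "channel count must be divisible by 4");
    return {nchw[0], nchw[1] / 4, nchw[2], nchw[3], 4};
}

int main() {
    auto packed = nchw_to_nchw4({8, 64, 28, 28});  // -> {8, 16, 28, 28, 4}
    for (std::size_t d : packed)
        std::printf("%zu ", d);
    std::printf("\n");
    return 0;
}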