Commit b82e8f00 authored by Megvii Engine Team

refactor(gopt): refactor the padding channel opt pass

GitOrigin-RevId: ee3f55aa66f21fe2d4a042298aafe4a0a02915f7
Parent f444d4fe
@@ -783,7 +783,8 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_optimize_options(
});
cb(nchw64, {
add_pass<FuseConvBiasNonlinPass>();
add_pass<PaddingChannelPass>();
add_pass(PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64));
add_pass<FuseConvBiasZPass>();
add_pass(EnableNCHW64Pass::make_nchw64_converter());
add_pass<ShuffleShuffleRemovePass>();
......
@@ -33,6 +33,54 @@ using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;
/* ==================== PaddingChannelPass ================= */
namespace {
size_t padding_int4(size_t in_channel, bool flag) {
static_cast<void>(flag);
if (in_channel <= 32) {
return (8 - (in_channel % 8)) % 8;
} else {
return (64 - (in_channel % 64)) % 64;
}
}
size_t padding_int8(size_t in_channel, bool flag) {
if (flag) {
if (in_channel <= 16) {
return (4 - (in_channel % 4)) % 4;
} else {
return (32 - (in_channel % 32)) % 32;
}
} else {
return (4 - (in_channel % 4)) % 4;
}
}
size_t padding_4(size_t in_channel, bool) {
return (4 - (in_channel % 4)) % 4;
}
} // namespace
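For reference, these helpers encode the channel alignments of the target layouts: 4-bit dtypes are padded to multiples of 8 (64 once the channel count exceeds 32, for NCHW64), int8 to multiples of 4 or 32 (the dp4a vs. tensor-core paths, selected by `flag`: the conv-bias path passes `true`, the deconv path `false`), and the NCHW44 dtypes always to multiples of 4. A minimal standalone sketch of the same arithmetic (plain C++, no MegEngine dependency):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdio>

// Standalone mirrors of padding_int4 / padding_int8 / padding_4 above.
size_t padding_int4(size_t c, bool /*flag*/) {
    return c <= 32 ? (8 - c % 8) % 8 : (64 - c % 64) % 64;
}
size_t padding_int8(size_t c, bool flag) {
    if (flag)  // conv-bias path: <=16 channels use dp4a, larger use tensor core
        return c <= 16 ? (4 - c % 4) % 4 : (32 - c % 32) % 32;
    return (4 - c % 4) % 4;  // deconv path: only 4-alignment
}
size_t padding_4(size_t c, bool) { return (4 - c % 4) % 4; }

int main() {
    assert(padding_int4(3, true) == 5);    // 3  -> 8
    assert(padding_int4(33, true) == 31);  // 33 -> 64
    assert(padding_int8(3, true) == 1);    // 3  -> 4  (dp4a)
    assert(padding_int8(20, true) == 12);  // 20 -> 32 (tensor core)
    assert(padding_int8(20, false) == 0);  // already 4-aligned
    assert(padding_4(6, true) == 2);       // 6  -> 8  (next multiple of 4)
    std::puts("channel alignment rules OK");
    return 0;
}
```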
std::unique_ptr<PaddingChannelPass> PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform layout_transform) {
MIDOUT_B("PaddingChannelPass::make")
using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;
auto ret = std::make_unique<PaddingChannelPass>();
auto& alignment_map = ret->m_alignment_map;
if (layout_transform == LayoutTrans::NCHW64) {
alignment_map[DTypeEnum::QuantizedS4] = padding_int4;
alignment_map[DTypeEnum::Quantized4Asymm] = padding_int4;
alignment_map[DTypeEnum::QuantizedS8] = padding_int8;
} else if (
layout_transform == LayoutTrans::NCHW44 ||
layout_transform == LayoutTrans::NCHW44_DOT) {
alignment_map[DTypeEnum::QuantizedS8] = padding_4;
alignment_map[DTypeEnum::Quantized8Asymm] = padding_4;
alignment_map[DTypeEnum::Float32] = padding_4;
}
ret->fill_opr_convert_fun(layout_transform);
return ret;
MIDOUT_E
}
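A hedged usage sketch of the new factory; the header paths are assumptions, but the call pattern matches the framework.cpp hunk above and the tests below:

```cpp
// Sketch: constructing the pass for an NCHW64 target, replacing the old
// default-constructed add_pass<PaddingChannelPass>().
#include "megbrain/gopt/framework.h"   // assumed header locations
#include "megbrain/gopt/inference.h"

using LayoutTrans = mgb::cg::GraphCommonOptimizeOptions::LayoutTransform;

mgb::SymbolVar pad_channels_for_nchw64(const mgb::SymbolVar& y) {
    auto endpoints = mgb::gopt::GraphOptimizer{}
                             .add_pass(mgb::gopt::PaddingChannelPass::make(
                                     LayoutTrans::NCHW64))
                             .apply({{y}})
                             .endpoint_vars();
    return endpoints[0];  // single endpoint in, single endpoint out
}
```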
const char* PaddingChannelPass::name() const {
return mgb_cstr_log("padding output channel to multiple of 4/32");
}
@@ -42,267 +90,240 @@ void PaddingChannelPass::apply(OptState& opt) const {
// do not check shape
opt.set_var_replace_check_flag(
VarReplaceCheckFlag::CHECK_ALL ^ VarReplaceCheckFlag::CHECK_SHAPE);
ThinHashSet<OperatorNodeBase*> padding_oprs;
ThinHashMap<
Typeinfo*,
thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>
opr_replace_funcs;
m_padding_oprs.clear();
auto rewriter = opt.graph().make_rewriter();
auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(
inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
inp->dtype().enumv() == DTypeEnum::QuantizedS32);
TensorShape shape{
inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 1);
return out.node();
};
auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(
inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
inp->dtype().enumv() == DTypeEnum::QuantizedS32);
TensorShape shape{
pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 0);
return out.node();
};
auto extract_subtensor = [](VarNode* inp,
const TensorShape& orig_shape) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->shape()[0] == orig_shape[0]);
mgb_assert(inp->shape()[2] == orig_shape[2]);
mgb_assert(inp->shape()[3] == orig_shape[3]);
size_t orig_channels = orig_shape[1];
auto x = SymbolVar(inp);
auto cv = [&x](int v) { return x.make_scalar(v); };
using AIdx = opr::Subtensor::AxisIndexer;
auto sub = opr::Subtensor::make(
x, {AIdx::make_interval(0, None, None, cv(1)),
AIdx::make_interval(1, None, cv(orig_channels), None),
AIdx::make_interval(2, None, None, cv(1)),
AIdx::make_interval(3, None, None, cv(1))});
return sub.node();
};
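The lambda above (hoisted into a member function later in this diff) keeps only the first `orig_channels` channels along axis 1. On a flat NCHW buffer the same semantics look like the sketch below (illustration only, not the graph-level `opr::Subtensor`):

```cpp
#include <cstddef>
#include <vector>

// Keep only the first orig_c channels of an NCHW tensor stored as a flat
// buffer; mirrors extract_subtensor's interval [0, orig_c) on axis 1.
std::vector<float> slice_channels(
        const std::vector<float>& src, size_t n, size_t c, size_t h, size_t w,
        size_t orig_c) {
    std::vector<float> dst(n * orig_c * h * w);
    for (size_t in = 0; in < n; ++in)
        for (size_t ic = 0; ic < orig_c; ++ic)      // drop channels >= orig_c
            for (size_t s = 0; s < h * w; ++s)
                dst[(in * orig_c + ic) * h * w + s] =
                        src[(in * c + ic) * h * w + s];
    return dst;
}
```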
// padding policy for conv bias with data type qint8
auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels, &pad_out_channels](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 3);
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
size_t new_in_channels = new_inp[0]->shape()[1];
// pad input channels
if (padding_oprs.count(opr->input(0)->owner_opr())) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels = 0;
mgb_assert(new_in_channels == in_channels);
if (in_channels <= 16) {
if (in_channels % 4)
pad_channels = 4 - (in_channels % 4); // pad to use dp4a
} else {
if (in_channels % 32)
pad_channels = 32 - (in_channels % 32); // pad to use tensorcore
auto on_opr = [this, &opt, &rewriter](OperatorNodeBase* opr) {
auto it = m_opr_replace_funcs.find(opr->dyn_typeinfo());
if (it != m_opr_replace_funcs.end()) {
VarNodeArray new_inp;
new_inp.reserve(opr->input().size());
for (auto&& inp : opr->input()) {
new_inp.push_back(rewriter.get_var(inp));
}
if (pad_channels > 0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
auto new_opr = (it->second)(opr, new_inp);
auto &&out0 = opr->output(), &&out1 = new_opr->output();
mgb_assert(
out0.size() == out1.size(),
"bad opr replace: src=%s{%s} dst=%s{%s}, "
"src.size=%zu "
"dst.size=%zu",
opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(),
new_opr->dyn_typeinfo()->name, out0.size(), out1.size());
for (size_t i = 0; i < out0.size(); ++i) {
if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
mgb_assert(!out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
auto src = out0[i];
auto dst = out1[i];
if (opt.graph().endpoint_contain(src) &&
!src->shape().eq_shape(dst->shape())) {
dst = extract_subtensor(dst, src->shape());
}
rewriter.replace_var(src, dst, nullptr);
}
}
}
out_channels = inps[1]->shape()[0];
in_channels = inps[1]->shape()[1];
size_t pad_channels = 0;
if (out_channels <= 16) {
if (out_channels % 4)
pad_channels = 4 - (out_channels % 4);
} else {
if (out_channels % 32)
pad_channels = 32 - (out_channels % 32);
}
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
padding_oprs.insert(opr);
rewriter.auto_replace_outputs(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
// padding policy for conv bias with data type qint4 and quint4
auto padding_policy_int4 = [&padding_oprs, &pad_in_channels, &pad_out_channels](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 3);
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
size_t new_in_channels = new_inp[0]->shape()[1];
// pad input channels
if (padding_oprs.count(opr->input(0)->owner_opr())) {
if (new_in_channels <= 32) {
if (new_in_channels % 8 == 0) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels_0 = 8 - (new_in_channels % 8);
size_t pad_channels_1 = 8 - (in_channels % 8);
inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
}
} else {
if (new_in_channels % 64 == 0) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels_0 = 64 - (new_in_channels % 64);
size_t pad_channels_1 = 64 - (in_channels % 64);
inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
}
}
MIDOUT_E
}
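The traversal above dispatches on each operator's typeinfo and, when a rewritten endpoint no longer matches its original shape, splices `extract_subtensor` back in so graph outputs keep their observable channel counts. A simplified model of that dispatch (hypothetical stand-in types; the real table is keyed by `Typeinfo*`):

```cpp
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-ins for OperatorNodeBase and the rewriter state.
struct Opr { std::string type; size_t out_channels, orig_channels; };
using ReplaceFn = std::function<Opr(const Opr&)>;

Opr rewrite_one(const Opr& opr,
                const std::unordered_map<std::string, ReplaceFn>& table,
                bool is_endpoint) {
    auto it = table.find(opr.type);
    if (it == table.end())
        return opr;  // no rule: auto_replace_outputs keeps it unchanged
    Opr replaced = it->second(opr);  // type-specific padding rule
    // endpoints must keep their observable shape: slice the padding back off
    if (is_endpoint && replaced.out_channels != opr.orig_channels)
        replaced.out_channels = opr.orig_channels;  // models extract_subtensor
    return replaced;
}
```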
VarNode* PaddingChannelPass::extract_subtensor(
VarNode* inp, const TensorShape& orig_shape) const {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->shape()[0] == orig_shape[0]);
mgb_assert(inp->shape()[2] == orig_shape[2]);
mgb_assert(inp->shape()[3] == orig_shape[3]);
size_t orig_channels = orig_shape[1];
auto x = SymbolVar(inp);
auto cv = [&x](int v) { return x.make_scalar(v); };
using AIdx = opr::Subtensor::AxisIndexer;
auto sub = opr::Subtensor::make(
x, {AIdx::make_interval(0, None, None, cv(1)),
AIdx::make_interval(1, None, cv(orig_channels), None),
AIdx::make_interval(2, None, None, cv(1)),
AIdx::make_interval(3, None, None, cv(1))});
return sub.node();
}
VarNode* PaddingChannelPass::pad_in_channels(VarNode* inp, size_t pad_channels) {
mgb_assert(inp->shape().ndim == 4);
TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2], inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 1);
return out.node();
}
VarNode* PaddingChannelPass::pad_out_channels(VarNode* inp, size_t pad_channels) {
mgb_assert(inp->shape().ndim == 4);
TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2], inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes = TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding = opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 0);
return out.node();
}
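Both helpers zero-fill the appended region, which is what makes the rewrite value-preserving: padded input channels meet zero weights and contribute nothing to the convolution. On a flat NCHW buffer the axis-1 case looks like this (sketch only; the pass builds a zero `ImmutableTensor` and a `Concat`):

```cpp
#include <cstddef>
#include <vector>

// Append pad_c zero-filled channels to an NCHW tensor stored as a flat
// buffer; mirrors pad_in_channels (Concat of a zero tensor on axis 1).
std::vector<float> pad_channels_with_zeros(
        const std::vector<float>& src, size_t n, size_t c, size_t h, size_t w,
        size_t pad_c) {
    std::vector<float> dst(n * (c + pad_c) * h * w, 0.f);  // zeros = padding
    for (size_t in = 0; in < n; ++in)
        for (size_t ic = 0; ic < c; ++ic)
            for (size_t s = 0; s < h * w; ++s)
                dst[(in * (c + pad_c) + ic) * h * w + s] =
                        src[(in * c + ic) * h * w + s];
    return dst;
}
```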
// unified padding policy for conv bias; the alignment rule is selected from
// m_alignment_map by the input dtype
OperatorNodeBase* PaddingChannelPass::padding_policy(
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 3);
    //! new weights and old weights have the same shape
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
size_t new_in_channels = new_inp[0]->shape()[1];
auto it = m_alignment_map.find(opr->input(0)->dtype().enumv());
if (it != m_alignment_map.end()) {
mgb_assert(it->second);
} else {
return serialization::copy_opr_shallow(*opr, inps, opr->config());
}
// pad input channels
if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
        //! the producer of the input var has been padded, but the input and
        //! output dtypes of that producer may differ, so their channel
        //! alignments may not match
size_t pad_channels_0 = it->second(new_in_channels, true);
size_t pad_channels_1 = it->second(in_channels, true);
if (pad_channels_0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
} else {
size_t pad_channels = 0;
mgb_assert(new_in_channels == in_channels);
if (in_channels <= 32) {
if (in_channels % 8)
pad_channels = 8 - (in_channels % 8);
} else {
if (in_channels % 64)
pad_channels = 64 - (in_channels % 64);
}
if (pad_channels > 0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
pad_channels_1 = new_in_channels - in_channels;
}
out_channels = inps[1]->shape()[0];
in_channels = inps[1]->shape()[1];
size_t pad_channels = 0;
if (out_channels <= 32) {
if (out_channels % 8)
pad_channels = 8 - (out_channels % 8);
} else {
if (out_channels % 64)
pad_channels = 64 - (out_channels % 64);
if (pad_channels_1) {
inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
}
} else {
mgb_assert(new_in_channels == in_channels);
size_t pad_channels = it->second(in_channels, true);
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
padding_oprs.insert(opr);
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
}
out_channels = inps[1]->shape()[0];
size_t pad_channels = it->second(out_channels, true);
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
m_padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
}
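In outline, the policy above (a) matches the weights to the producer's already-padded data channels, or pads both data and weights together when the producer is untouched, then (b) pads the weights' output channels and the bias, recording the opr so consumers know its outputs are padded. A simplified control-flow sketch on plain channel counts (`align` stands for the dtype-specific helper from `m_alignment_map`; the dtype-mismatch re-alignment branch is elided):

```cpp
#include <cstddef>
#include <functional>

struct ConvShapes {
    size_t data_c, weight_ic, weight_oc, bias_c;
};

void padding_policy_sketch(
        ConvShapes& s, const std::function<size_t(size_t)>& align,
        bool producer_padded, bool& record_as_padded) {
    if (producer_padded) {
        // data channels were already padded upstream; bring weights along
        s.weight_ic = s.data_c;
    } else if (size_t pad = align(s.weight_ic)) {
        s.data_c += pad;      // pad input data channels ...
        s.weight_ic += pad;   // ... and the weights' input channels together
    }
    if (size_t pad = align(s.weight_oc)) {
        s.weight_oc += pad;   // pad weights' output channels
        s.bias_c += pad;      // bias must follow the output channels
        record_as_padded = true;  // consumers will see padded outputs
    }
}
```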
opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
[&padding_oprs, &padding_policy_qint8, &padding_policy_int4](
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
return padding_policy_qint8(opr, new_inp);
} else if (
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm) {
return padding_policy_int4(opr, new_inp);
} else {
mgb_assert(
padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"conv bias operator for data type(%s) cannot be "
"padded channel. "
"consumer(%s), producer(%s)",
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(
*opr, new_inp, opr->config());
}
};
opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
[&padding_oprs, &pad_in_channels, &pad_out_channels](
OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
void PaddingChannelPass::fill_opr_convert_fun(LayoutTrans layout_trans) {
add_convbias_replace_func(layout_trans);
add_conv_backward_data_replace_func(layout_trans);
add_format_aware_opr_replace_func(layout_trans);
add_elemwise_like_opr_replace_func(layout_trans);
add_nonpadding_oprs_replace_func(layout_trans);
}
void PaddingChannelPass::add_convbias_replace_func(LayoutTrans layout_trans) {
if (layout_trans == LayoutTrans::NCHW64) {
m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
return padding_policy(opr, new_inp);
} else if (
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4 ||
opr->input(0)->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) {
return padding_policy(opr, new_inp);
} else {
mgb_assert(
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"conv bias operator for data type(%s) cannot be "
"padded channel. "
"consumer(%s), producer(%s)",
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(
*opr, new_inp, opr->config());
}
};
} else if (layout_trans == LayoutTrans::NCHW44) {
m_opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
return padding_policy(opr, new_inp);
};
}
}
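For NCHW64 the replace function gates on dtype: only quantized 4- and 8-bit conv bias inputs go through the padding policy, and anything else asserts that its producer was left unpadded; for NCHW44 every conv bias is eligible. A tiny sketch of that gate (the enum is a stand-in for `DTypeEnum`):

```cpp
// Stand-in for DTypeEnum, just to show the NCHW64 eligibility gate.
enum class DT { QuantizedS8, QuantizedS4, Quantized4Asymm, Float32 };

bool conv_bias_padding_eligible_nchw64(DT input_dtype) {
    return input_dtype == DT::QuantizedS8 || input_dtype == DT::QuantizedS4 ||
           input_dtype == DT::Quantized4Asymm;
}
```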
void PaddingChannelPass::add_conv_backward_data_replace_func(LayoutTrans layout_trans) {
if (layout_trans == LayoutTrans::NCHW64) {
m_opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
[this](OperatorNodeBase* opr, const VarNodeArray& new_inp) {
if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
mgb_assert(
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"conv bwd data operator for data type(%s) cannot "
"be "
"padded channel. "
"consumer(%s), producer(%s)",
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(
*opr, new_inp, opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(
padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"conv bwd data operator for data type(%s) cannot "
"be "
"padded channel. "
"consumer(%s), producer(%s)",
opr->input(0)->dtype().name(), opr->cname(),
opr->input(0)->owner_opr()->cname());
return serialization::copy_opr_shallow(
*opr, new_inp, opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(
new_inp.size() == 2,
"deconv (conv bwd data) operator for inference can "
"only have 2 input vars(got:%zu)",
new_inp.size());
mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(0)->shape()[0];
size_t in_channels = opr->input(0)->shape()[1];
size_t new_out_channels = new_inp[1]->shape()[1];
// pad output channels
if (padding_oprs.count(opr->input(1)->owner_opr())) {
size_t pad_channels = new_out_channels - out_channels;
inps[0] = pad_out_channels(new_inp[0], pad_channels);
} else {
size_t pad_channels = 0;
if (out_channels % 4)
pad_channels = 4 - (out_channels % 4);
if (pad_channels > 0) {
new_inp.size() == 2,
"deconv (conv bwd data) operator for inference can "
"only have 2 input vars(got:%zu)",
new_inp.size());
mgb_assert(opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(0)->shape()[0];
size_t in_channels = opr->input(0)->shape()[1];
size_t new_out_channels = new_inp[1]->shape()[1];
auto it = m_alignment_map.find(opr->input(1)->dtype().enumv());
// pad output channels
if (m_padding_oprs.count(opr->input(1)->owner_opr())) {
size_t pad_channels = new_out_channels - out_channels;
inps[0] = pad_out_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels = it->second(out_channels, false);
if (pad_channels > 0) {
inps[0] = pad_out_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
}
}
out_channels = inps[0]->shape()[0];
in_channels = inps[0]->shape()[1];
// pad input channels
size_t pad_channels = 0;
if (in_channels % 4)
pad_channels = 4 - (in_channels % 4);
if (pad_channels > 0) {
inps[0] = pad_in_channels(inps[0], pad_channels);
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
auto replace_format_aware_opr = [&padding_oprs](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
out_channels = inps[0]->shape()[0];
// pad input channels
size_t pad_channels = it->second(in_channels, false);
if (pad_channels > 0) {
inps[0] = pad_in_channels(inps[0], pad_channels);
m_padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
}
}
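For deconv the roles flip: the filter's axis 0 must track the (possibly padded) channel axis of the data input, while padding the filter's axis 1 pads the channels this opr produces, and only then is the opr recorded in `m_padding_oprs`. A bookkeeping sketch under the `flag=false` (plain 4-alignment) rule used here:

```cpp
#include <cstddef>

struct DeconvShapes {
    size_t filter_oc;  // filter axis 0: matches the data input's channels
    size_t filter_ic;  // filter axis 1: channels this opr produces
    size_t data_c;     // channel axis of the data input
};

size_t align4(size_t c) { return (4 - c % 4) % 4; }  // flag=false branch

// Returns true when the opr must be recorded in m_padding_oprs.
bool pad_deconv_sketch(DeconvShapes& s, bool producer_padded,
                       size_t new_data_c) {
    if (producer_padded) {
        size_t pad = new_data_c - s.data_c;  // follow the producer's padding
        s.filter_oc += pad;
        s.data_c = new_data_c;
    } else if (size_t pad = align4(s.filter_oc)) {
        s.filter_oc += pad;
        s.data_c += pad;
    }
    if (size_t pad = align4(s.filter_ic)) {
        s.filter_ic += pad;  // the produced channels become padded
        return true;
    }
    return false;
}
```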
void PaddingChannelPass::add_format_aware_opr_replace_func(LayoutTrans) {
auto replace_format_aware_opr = [this](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 &&
opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) {
mgb_assert(
padding_oprs.count(opr->input(0)->owner_opr()) == 0,
m_padding_oprs.count(opr->input(0)->owner_opr()) == 0,
"operator(type:%s,name:%s) for data type(%s) cannot be "
"padded channel. extra info:"
"consumer(%s), producer(%s)",
@@ -312,18 +333,19 @@ void PaddingChannelPass::apply(OptState& opt) const {
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
}
mgb_assert(opr->input().size() == new_inp.size());
if (padding_oprs.count(opr->input(0)->owner_opr())) {
padding_oprs.insert(opr);
if (m_padding_oprs.count(opr->input(0)->owner_opr())) {
m_padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr;
opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
m_opr_replace_funcs[opr::PoolingForward::typeinfo()] = replace_format_aware_opr;
m_opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
replace_format_aware_opr;
}
auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
void PaddingChannelPass::add_elemwise_like_opr_replace_func(LayoutTrans) {
auto replace_elemwise_like_opr = [this](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
bool have_padding_inp = false;
bool padding_all_inps = true;
@@ -331,7 +353,7 @@ void PaddingChannelPass::apply(OptState& opt) const {
size_t channels_after_padding = 0;
size_t i = 0;
for (auto&& cur_inp : opr->input()) {
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
if (padding_cur_inp) {
if (!have_padding_inp)
have_padding_inp = true;
@@ -349,7 +371,7 @@ void PaddingChannelPass::apply(OptState& opt) const {
auto inps = new_inp;
for (size_t i = 0; i < new_inp.size(); ++i) {
auto cur_inp = opr->input(i);
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
if (padding_cur_inp) {
inps[i] = extract_subtensor(inps[i], cur_inp->shape());
}
@@ -357,72 +379,34 @@ void PaddingChannelPass::apply(OptState& opt) const {
return serialization::copy_opr_shallow(*opr, inps, opr->config());
}
if (padding_all_inps) {
padding_oprs.insert(opr);
m_padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr;
opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
m_opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] = replace_elemwise_like_opr;
m_opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
m_opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
}
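Elemwise-like oprs (Elemwise, ElemwiseMultiType, TypeCvt) can stay in padded form only when every input carries the same padded channel count; if the paddings disagree, the padded inputs are first sliced back to their original shapes. A simplified version of that decision (the scalar/broadcast handling in the elided hunk is omitted):

```cpp
#include <cstddef>
#include <vector>

struct ElemInp {
    size_t channels;
    bool padded;  // producer is in m_padding_oprs
};

// true: propagate padding to this opr; false: extract_subtensor on the
// padded inputs and emit the opr on original shapes.
bool keep_padding(const std::vector<ElemInp>& inps) {
    size_t padded_c = 0;
    bool all_padded = !inps.empty();
    for (auto&& i : inps) {
        if (!i.padded) {
            all_padded = false;
        } else if (padded_c == 0) {
            padded_c = i.channels;   // first padded input fixes the count
        } else if (i.channels != padded_c) {
            all_padded = false;      // paddings disagree
        }
    }
    return all_padded;
}
```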
auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
void PaddingChannelPass::add_nonpadding_oprs_replace_func(LayoutTrans) {
auto replace_nonpadding_oprs = [this](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto inps = new_inp;
for (size_t i = 0; i < new_inp.size(); ++i) {
auto cur_inp = opr->input(i);
bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
bool padding_cur_inp = m_padding_oprs.count(cur_inp->owner_opr()) > 0;
if (padding_cur_inp) {
inps[i] = extract_subtensor(inps[i], cur_inp->shape());
}
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
auto on_opr = [&opt, &rewriter, &opr_replace_funcs,
&extract_subtensor](OperatorNodeBase* opr) {
auto it = opr_replace_funcs.find(opr->dyn_typeinfo());
if (it != opr_replace_funcs.end()) {
VarNodeArray new_inp;
new_inp.reserve(opr->input().size());
for (auto&& inp : opr->input()) {
new_inp.push_back(rewriter.get_var(inp));
}
auto new_opr = (it->second)(opr, new_inp);
auto &&out0 = opr->output(), &&out1 = new_opr->output();
mgb_assert(
out0.size() == out1.size(),
"bad opr replace: src=%s{%s} dst=%s{%s}, "
"src.size=%zu "
"dst.size=%zu",
opr->cname(), opr->dyn_typeinfo()->name, new_opr->cname(),
new_opr->dyn_typeinfo()->name, out0.size(), out1.size());
for (size_t i = 0; i < out0.size(); ++i) {
if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
mgb_assert(!out1[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT));
auto src = out0[i];
auto dst = out1[i];
if (opt.graph().endpoint_contain(src) &&
!src->shape().eq_shape(dst->shape())) {
dst = extract_subtensor(dst, src->shape());
}
rewriter.replace_var(src, dst, nullptr);
}
}
} else {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
m_opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
m_opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
}
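Shape-consuming oprs (Reshape, GetVarShape, Concat, Reduce, Subtensor) must always observe the original channel counts, so every padded input is sliced back and the opr itself is never marked as padded. A sketch of that invariant (hypothetical `Var` stand-in; `slice_back` models `extract_subtensor`):

```cpp
#include <cstddef>
#include <vector>

struct Var {
    size_t channels;       // current (possibly padded) channel count
    size_t orig_channels;  // channel count before padding
};

Var slice_back(Var v) {  // models extract_subtensor on axis 1
    v.channels = v.orig_channels;
    return v;
}

std::vector<Var> unpad_inputs(std::vector<Var> inps) {
    for (auto& v : inps)
        if (v.channels != v.orig_channels)  // producer padded this input
            v = slice_back(v);
    return inps;  // Reshape/Concat/Reduce/... then run on original shapes
}
```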
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -509,8 +509,38 @@ public:
*/
class PaddingChannelPass final : public Pass {
public:
using ChannelAlignmentMap =
ThinHashMap<DTypeEnum, std::function<size_t(size_t, bool)>>;
using LayoutTrans = cg::GraphCommonOptimizeOptions::LayoutTransform;
const char* name() const override;
void apply(OptState& opt) const override;
void fill_opr_convert_fun(LayoutTrans layout_trans);
using ReplaceFuncs = ThinHashMap<
Typeinfo*,
thin_function<OperatorNodeBase*(OperatorNodeBase*, const VarNodeArray&)>>;
    //! make a channel-padding opt pass for the given tensor format
static std::unique_ptr<PaddingChannelPass> make(LayoutTrans layout_transform);
private:
VarNode* extract_subtensor(VarNode* inp, const TensorShape& orig_shape) const;
VarNode* pad_in_channels(VarNode* inp, size_t pad_channels);
VarNode* pad_out_channels(VarNode* inp, size_t pad_channels);
OperatorNodeBase* padding_policy(
OperatorNodeBase* opr, const VarNodeArray& new_inp);
void add_convbias_replace_func(LayoutTrans layout_transform);
void add_conv_backward_data_replace_func(LayoutTrans layout_transform);
void add_format_aware_opr_replace_func(LayoutTrans layout_transform);
void add_elemwise_like_opr_replace_func(LayoutTrans layout_transform);
void add_nonpadding_oprs_replace_func(LayoutTrans layout_transform);
ChannelAlignmentMap m_alignment_map;
ReplaceFuncs m_opr_replace_funcs;
mutable ThinHashSet<OperatorNodeBase*> m_padding_oprs;
};
/*!
......
#include "megbrain/graph/cg.h"
#include "megbrain/opr/dnn/local.h"
#include "megbrain/test/helper.h"
@@ -5037,7 +5038,8 @@ TEST(TestGoptInference, PaddingChannels) {
SymbolVar y3_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y3}})
.endpoint_vars(),
y3_pad);
@@ -5101,7 +5103,8 @@ TEST(TestGoptInference, ConcatAfterPaddingChannels) {
SymbolVar y2_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y2}})
.endpoint_vars(),
y2_pad);
@@ -5166,7 +5169,8 @@ TEST(TestGoptInference, PaddingChannelsWithPooling) {
SymbolVar y1_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y1}})
.endpoint_vars(),
y1_pad);
@@ -5232,7 +5236,8 @@ TEST(TestGoptInference, PaddingChannelsWithWarpPerspective) {
SymbolVar y1_pad;
unpack_vector(
gopt::GraphOptimizer{}
.add_pass<gopt::PaddingChannelPass>()
.add_pass(gopt::PaddingChannelPass::make(
cg::GraphCommonOptimizeOptions::LayoutTransform::NCHW64))
.apply({{y1}})
.endpoint_vars(),
y1_pad);
......