提交 8a3eb05a 编写于 作者: M Megvii Engine Team

refactor(mgb/gopt): refactor tensor reformat opt pass

GitOrigin-RevId: a1b1e89b76e4fbdca4f481156bb8af6cae8fe4d8
上级 c33126ab
......@@ -120,10 +120,6 @@ Dimension Dimension::operator/(const Dimension& rhs) const {
static_cast<char>(m_name), static_cast<char>(rhs.m_name));
if (operator==(rhs))
return Dimension(m_name, 1, 1);
megdnn_assert(
!(*this < rhs),
"Divisor must be smaller than dividend(dividend:%s, divisor:%s)",
to_string().c_str(), rhs.to_string().c_str());
if (m_stride == rhs.m_stride) {
if (m_extent == UNDETERMINED_EXTENT) {
megdnn_assert(rhs.m_extent != UNDETERMINED_EXTENT,
......
/**
* \file src/gopt/impl/folding_conv_dimshuffle.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/utils/hash_ct.h"
#include "midout.h"
#include "megbrain/gopt/reformat_manager.h"
#if CUDA_VERSION >= 10020
MIDOUT_DECL(megbrain_folding_conv_dimshuffle)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_conv_dimshuffle, \
midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();
using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;
/* ==================== FoldingConvBiasDimshufflePass ================= */
const char* FoldingConvBiasDimshufflePass::name() const {
    // Human-readable pass identifier used by the optimizer's logging.
    static const char* const kPassName =
            mgb_cstr_log("folding conv bias dimshuffle pass");
    return kPassName;
}
void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
    MIDOUT_B("FoldingConvBiasDimshufflePass::apply");
    using DepType = cg::OperatorNodeProp::DepType;
    // For every operator of a type that may participate in a fusion chain,
    // record which operators read its outputs (and with what dependency
    // type). This is used below to verify that intermediate results of a
    // matched chain are not consumed outside the chain.
    ThinHashMap<OperatorNodeBase*,
                SmallVector<std::pair<OperatorNodeBase*, DepType>>>
            readers;
    static const ThinHashSet<Typeinfo*> opr_type_list = {
            opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(),
            opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()};
    opt.graph().iter([&readers](OperatorNodeBase* opr) {
        for (auto&& i : opr->node_prop().dep_map()) {
            if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
                readers[i.first->owner_opr()].emplace_back(opr, i.second);
            }
        }
    });
    auto rewriter = opt.graph().make_rewriter();
    // Matcher 1: conv_bias(NCHW4, qint8) -> dimshuffle -> reshape ->
    // typecvt(qint8 -> float32). On a match the whole chain is replaced by a
    // single conv_bias with format NCHW4_NCHW and a float32 output; the bias
    // is reformatted from NCHWc4 to NCHW and converted to float32.
    // `opr` is expected to be the tail (the TypeCvt); matching walks the
    // chain backwards through input(0).
    auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers](
                                                       OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check typecvt
        auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
        if (typecvt == nullptr)
            return false;
        auto inp_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                         out_dtype.enumv() == DTypeEnum::Float32;
        if (!is_s82f32)
            return false;
        opr_set.insert(opr);
        // check reshape
        auto reshape =
                try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr());
        if (reshape == nullptr)
            return false;
        opr_set.insert(reshape);
        // collect all DEV_VALUE readers of the reshape output; they must all
        // belong to the matched chain for the fusion to be legal
        for (auto&& i : readers[reshape]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check shuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        // pattern (0, 1, 4, 2, 3) with the innermost dim == 4 is the
        // dimshuffle half of an NCHW4 -> NCHW layout conversion
        bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                             param.pattern[2] == 4 && param.pattern[3] == 2 &&
                             param.pattern[4] == 3 &&
                             shuffle->input(0)->shape()[4] == 4;
        if (!is_nchw42nchw)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        // only src/filter/bias form is handled (no z tensor)
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // abort if any intermediate value escapes the matched chain
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        // bias layout must follow the new output format: NCHWc4 -> NCHW,
        // then cast to float32 to match the fused output dtype
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc4, TensorFormats::NCHW})({bias});
        new_bias = opr::TypeCvt::make(new_bias, dtype::Float32()).node();
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{dtype::Float32()});
        rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
                             mgb_cstr_log("replace conv_bias + typecvt + "
                                          "dimshuffle + "
                                          "reshape to conv_bias(NCHW4_NCHW)"));
        return true;
    };
    // Matcher 2: conv_bias(NCHW4, qint8) -> reshape -> dimshuffle -> reshape,
    // where the dimshuffle/reshape pair implements NCHW4 -> NCHW32. Replaced
    // by conv_bias with format NCHW4_NCHW32. `opr` is the trailing reshape.
    auto try_conv_reformat_nchw42nchw32 = [&rewriter,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        // pattern (0, 1, 3, 4, 2, 5) with output dims [5]==4 and [4]==8 is
        // the dimshuffle half of an NCHW4 -> NCHW32 conversion (4*8 == 32)
        bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 3 && param.pattern[3] == 4 &&
                               param.pattern[4] == 2 && param.pattern[5] == 5 &&
                               shuffle->output(0)->shape()[5] == 4 &&
                               shuffle->output(0)->shape()[4] == 8;
        if (!is_nchw42nchw32)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // abort if any intermediate value escapes the matched chain
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        // bias must be reformatted to the new NCHWc32 output layout
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc4, TensorFormats::NCHWc32})({bias});
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                conv_bias->config());
        rewriter.replace_var(
                opr->output(0), conv_bias_shuffle.node(),
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW4_NCHW32)"));
        return true;
    };
    // Matcher 3: conv_bias(NCHW4, qint8) -> typecvt(qint8 -> q4) ->
    // dimshuffle -> reshape implementing NCHW4 -> NHWC. Replaced by
    // conv_bias with format NCHW4_NHWC producing the 4-bit dtype directly.
    auto try_conv_reformat_nchw42nhwc = [&rewriter,
                                         &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape = try_cast_as_op<opr::Reshape>(opr);
        if (reshape == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle =
                try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 5)
            return false;
        // pattern (0, 2, 3, 1, 4) with innermost dim == 4 moves the channel
        // dims to the back: NCHW4 -> NHWC
        bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 &&
                             param.pattern[2] == 3 && param.pattern[3] == 1 &&
                             param.pattern[4] == 4 &&
                             shuffle->output(0)->shape()[4] == 4;
        if (!is_nchw42nhwc)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check typecvt: qint8 -> qint4 (signed or asymmetric)
        auto typecvt =
                try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr());
        if (typecvt == nullptr)
            return false;
        auto in_dtype = typecvt->input(0)->dtype(),
             out_dtype = typecvt->output(0)->dtype();
        bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                        (out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
                         out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
        if (!is_s82s4)
            return false;
        opr_set.insert(typecvt);
        for (auto&& i : readers[typecvt]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                          conv_bias->param().format ==
                                  megdnn::param::ConvBias::Format::NCHW4;
        if (!is_s8nchw4)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // abort if any intermediate value escapes the matched chain
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc4, TensorFormats::NHWC})({bias});
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
        // the fused conv emits the 4-bit dtype the original typecvt produced
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                OperatorNodeConfig{out_dtype});
        rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
                             mgb_cstr_log("replace conv_bias + "
                                          "reformat to conv_bias(NCHW4_NHWC)"));
        return true;
    };
    // Matcher 4: conv_bias(NCHW32, qint8) -> reshape -> dimshuffle ->
    // reshape implementing NCHW32 -> NCHW4. Replaced by conv_bias with
    // format NCHW32_NCHW4.
    auto try_conv_reformat_nchw322nchw4 = [&rewriter,
                                           &readers](OperatorNodeBase* opr) {
        ThinHashSet<OperatorNodeBase*> opr_set;
        ThinHashSet<OperatorNodeBase*> reader_set;
        // check reshape
        auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
        if (reshape1 == nullptr)
            return false;
        opr_set.insert(opr);
        // check dimshuffle
        auto shuffle = try_cast_as_op<opr::Dimshuffle>(
                reshape1->input(0)->owner_opr());
        if (shuffle == nullptr)
            return false;
        auto&& param = shuffle->param();
        if (param.pattern_len != 6)
            return false;
        // pattern (0, 1, 4, 2, 3, 5) with input dims [5]==4 and [4]==8 is
        // the dimshuffle half of an NCHW32 -> NCHW4 conversion
        bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
                               param.pattern[2] == 4 && param.pattern[3] == 2 &&
                               param.pattern[4] == 3 && param.pattern[5] == 5 &&
                               shuffle->input(0)->shape()[5] == 4 &&
                               shuffle->input(0)->shape()[4] == 8;
        if (!is_nchw322nchw4)
            return false;
        opr_set.insert(shuffle);
        for (auto&& i : readers[shuffle]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check reshape
        auto reshape2 =
                try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
        if (reshape2 == nullptr)
            return false;
        opr_set.insert(reshape2);
        for (auto&& i : readers[reshape2]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // check conv bias
        auto conv_bias =
                try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
        if (conv_bias == nullptr)
            return false;
        auto inp_dtype = conv_bias->input(0)->dtype();
        bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
                           conv_bias->param().format ==
                                   megdnn::param::ConvBias::Format::NCHW32;
        if (!is_s8nchw32)
            return false;
        if (conv_bias->input().size() != 3)
            return false;
        opr_set.insert(conv_bias);
        for (auto&& i : readers[conv_bias]) {
            if (i.second & DepType::DEV_VALUE) {
                reader_set.insert(i.first);
            }
        }
        // abort if any intermediate value escapes the matched chain
        for (auto reader : reader_set) {
            if (opr_set.count(reader) <= 0) {
                return false;
            }
        }
        auto src = rewriter.get_var(conv_bias->input(0)),
             filter = rewriter.get_var(conv_bias->input(1)),
             bias = rewriter.get_var(conv_bias->input(2));
        auto new_bias = ReformatManager::instance().get(ReformatKey{
                TensorFormats::NCHWc32, TensorFormats::NCHWc4})({bias});
        auto new_param = conv_bias->param();
        new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4;
        auto conv_bias_shuffle = opr::ConvBias::make(
                src, filter, new_bias, new_param, conv_bias->execution_policy(),
                conv_bias->config());
        rewriter.replace_var(
                opr->output(0), conv_bias_shuffle.node(),
                mgb_cstr_log("replace conv_bias + "
                             "reformat to conv_bias(NCHW32_NCHW4)"));
        return true;
    };
    MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4);
    MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32);
    // Try the matchers in order; the first one that fires rewrites the chain,
    // otherwise the operator is copied through unchanged.
    auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
                   &try_conv_reformat_nchw42nchw32,
                   &try_conv_reformat_nchw42nhwc,
                   &try_conv_reformat_nchw322nchw4,
                   &rewriter](OperatorNodeBase* opr) {
        if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
            !try_conv_reformat_nchw42nchw32(opr) &&
            !try_conv_reformat_nchw42nhwc(opr) &&
            !try_conv_reformat_nchw322nchw4(opr)) {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
/**
* \file src/gopt/impl/padding_channel.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/misc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/tensor_format.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/gopt/misc.h"
#include "megbrain/utils/hash_ct.h"
#include "midout.h"
#include "megbrain/gopt/reformat_manager.h"
MIDOUT_DECL(megbrain_padding_channel)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_padding_channel, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();
using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;
/* ==================== PaddingChannelPass ================= */
const char* PaddingChannelPass::name() const {
    // Human-readable pass identifier used by the optimizer's logging.
    static const char* const kPassName =
            mgb_cstr_log("padding output channel to multiple of 4/32");
    return kPassName;
}
void PaddingChannelPass::apply(OptState& opt) const {
    MIDOUT_B("PaddingChannelPass::apply");
    // do not check shape: padding intentionally changes tensor shapes, so
    // shape equality between the old and new vars must not be enforced
    opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
                                   VarReplaceCheckFlag::CHECK_SHAPE);
    // operators whose output channels have been padded; consumers consult
    // this set to decide whether their inputs carry extra padded channels
    ThinHashSet<OperatorNodeBase*> padding_oprs;
    ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*(
                                   OperatorNodeBase*, const VarNodeArray&)>>
            opr_replace_funcs;
    auto rewriter = opt.graph().make_rewriter();
    // Append `pad_channels` zero-filled channels along axis 1 (input-channel
    // axis for NCHW activations / weights). Returns the concatenated var.
    auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                   inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS32);
        TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2],
                          inp->shape()[3]};
        std::shared_ptr<HostTensorND> host_val =
                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
        host_val->resize(shape);
        auto ptr = host_val->raw_ptr();
        size_t size_bytes =
                TensorLayout{shape, inp->dtype()}.span().dist_byte();
        // zero-fill so the padded channels contribute nothing to the conv
        std::memset(ptr, 0, size_bytes);
        auto padding =
                opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
        auto out = opr::Concat::make({inp, padding}, 1);
        return out.node();
    };
    // Append `pad_channels` zero-filled channels along axis 0 (output-channel
    // axis of a weight tensor).
    auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
                   inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
                   inp->dtype().enumv() == DTypeEnum::QuantizedS32);
        TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2],
                          inp->shape()[3]};
        std::shared_ptr<HostTensorND> host_val =
                std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
        host_val->resize(shape);
        auto ptr = host_val->raw_ptr();
        size_t size_bytes =
                TensorLayout{shape, inp->dtype()}.span().dist_byte();
        std::memset(ptr, 0, size_bytes);
        auto padding =
                opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
        auto out = opr::Concat::make({inp, padding}, 0);
        return out.node();
    };
    // Slice a padded tensor back to its original channel count (axis 1);
    // all other axes must already match the original shape.
    auto extract_subtensor = [](VarNode* inp,
                                const TensorShape& orig_shape) -> VarNode* {
        mgb_assert(inp->shape().ndim == 4);
        mgb_assert(inp->shape()[0] == orig_shape[0]);
        mgb_assert(inp->shape()[2] == orig_shape[2]);
        mgb_assert(inp->shape()[3] == orig_shape[3]);
        size_t orig_channels = orig_shape[1];
        auto x = SymbolVar(inp);
        // NOTE(review): cv takes int, so orig_channels (size_t) narrows here;
        // presumably channel counts always fit in int — confirm
        auto cv = [&x](int v) { return x.make_scalar(v); };
        using AIdx = opr::Subtensor::AxisIndexer;
        auto sub = opr::Subtensor::make(
                x, {AIdx::make_interval(0, None, None, cv(1)),
                    AIdx::make_interval(1, None, cv(orig_channels), None),
                    AIdx::make_interval(2, None, None, cv(1)),
                    AIdx::make_interval(3, None, None, cv(1))});
        return sub.node();
    };
    // padding policy for conv bias with data type qint8
    // in-channel alignment: 4 (dp4a) for small channel counts, 32
    // (tensorcore) otherwise; out channels padded to the same alignments
    auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels,
                                 &pad_out_channels](
                                        OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        mgb_assert(new_inp.size() == 3);
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
        auto inps = new_inp;
        size_t out_channels = opr->input(1)->shape()[0];
        size_t in_channels = opr->input(1)->shape()[1];
        size_t new_in_channels = new_inp[0]->shape()[1];
        // pad input channels
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            // producer already padded the activation; pad the filter's
            // in-channel dim to match
            size_t pad_channels = new_in_channels - in_channels;
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        } else {
            size_t pad_channels = 0;
            mgb_assert(new_in_channels == in_channels);
            if (in_channels <= 16) {
                if (in_channels % 4)
                    pad_channels = 4 - (in_channels % 4);  // pad to use dp4a
            } else {
                if (in_channels % 32)
                    pad_channels =
                            32 - (in_channels % 32);  // pad to use tensorcore
            }
            if (pad_channels > 0) {
                inps[0] = pad_in_channels(new_inp[0], pad_channels);
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            }
        }
        out_channels = inps[1]->shape()[0];
        in_channels = inps[1]->shape()[1];
        // pad output channels (filter axis 0 and bias channel axis)
        size_t pad_channels = 0;
        if (out_channels <= 16) {
            if (out_channels % 4)
                pad_channels = 4 - (out_channels % 4);
        } else {
            if (out_channels % 32)
                pad_channels = 32 - (out_channels % 32);
        }
        if (pad_channels > 0) {
            inps[1] = pad_out_channels(inps[1], pad_channels);
            inps[2] = pad_in_channels(inps[2], pad_channels);
            // record that this opr's output now carries padded channels
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    // padding policy for conv bias with data type qint4 and quint4
    // alignments are 8 (small channel counts) and 64 (large), mirroring the
    // qint8 policy above with 4-bit-friendly multiples
    auto padding_policy_int4 = [&padding_oprs, &pad_in_channels,
                                &pad_out_channels](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        mgb_assert(new_inp.size() == 3);
        mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
        auto inps = new_inp;
        size_t out_channels = opr->input(1)->shape()[0];
        size_t in_channels = opr->input(1)->shape()[1];
        size_t new_in_channels = new_inp[0]->shape()[1];
        // pad input channels
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            if (new_in_channels <= 32) {
                if (new_in_channels % 8 == 0) {
                    // producer padding is already aligned; only filter needs
                    // its in-channels extended to match
                    size_t pad_channels = new_in_channels - in_channels;
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                } else {
                    // re-align both activation and filter to a multiple of 8
                    size_t pad_channels_0 = 8 - (new_in_channels % 8);
                    size_t pad_channels_1 = 8 - (in_channels % 8);
                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                }
            } else {
                if (new_in_channels % 64 == 0) {
                    size_t pad_channels = new_in_channels - in_channels;
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                } else {
                    size_t pad_channels_0 = 64 - (new_in_channels % 64);
                    size_t pad_channels_1 = 64 - (in_channels % 64);
                    inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
                }
            }
        } else {
            size_t pad_channels = 0;
            mgb_assert(new_in_channels == in_channels);
            if (in_channels <= 32) {
                if (in_channels % 8)
                    pad_channels = 8 - (in_channels % 8);
            } else {
                if (in_channels % 64)
                    pad_channels = 64 - (in_channels % 64);
            }
            if (pad_channels > 0) {
                inps[0] = pad_in_channels(new_inp[0], pad_channels);
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            }
        }
        out_channels = inps[1]->shape()[0];
        in_channels = inps[1]->shape()[1];
        // pad output channels
        size_t pad_channels = 0;
        if (out_channels <= 32) {
            if (out_channels % 8)
                pad_channels = 8 - (out_channels % 8);
        } else {
            if (out_channels % 64)
                pad_channels = 64 - (out_channels % 64);
        }
        if (pad_channels > 0) {
            inps[1] = pad_out_channels(inps[1], pad_channels);
            inps[2] = pad_in_channels(inps[2], pad_channels);
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    // dispatch on the conv bias input dtype; non-quantized convs must not
    // receive padded inputs and are copied through unchanged
    opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
            [&padding_oprs, &padding_policy_qint8, &padding_policy_int4](
                    OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                if (opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS8) {
                    return padding_policy_qint8(opr, new_inp);
                } else if (opr->input(0)->dtype().enumv() ==
                                   DTypeEnum::QuantizedS4 ||
                           opr->input(0)->dtype().enumv() ==
                                   DTypeEnum::Quantized4Asymm) {
                    return padding_policy_int4(opr, new_inp);
                } else {
                    mgb_assert(
                            padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                            "conv bias operator for data type(%s) cannot be "
                            "padded channel. "
                            "consumer(%s), producer(%s)",
                            opr->input(0)->dtype().name(), opr->cname(),
                            opr->input(0)->owner_opr()->cname());
                    return serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
                }
            };
    // deconv: pad the filter's output channels (axis 0) to match an already
    // padded gradient input, then align the filter's input channels to 4
    opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
            [&padding_oprs, &pad_in_channels, &pad_out_channels](
                    OperatorNodeBase* opr, const VarNodeArray& new_inp) {
                if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
                    mgb_assert(
                            padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                            "conv bwd data operator for data type(%s) cannot "
                            "be "
                            "padded channel. "
                            "consumer(%s), producer(%s)",
                            opr->input(0)->dtype().name(), opr->cname(),
                            opr->input(0)->owner_opr()->cname());
                    return serialization::copy_opr_shallow(*opr, new_inp,
                                                           opr->config());
                }
                mgb_assert(opr->input().size() == new_inp.size());
                mgb_assert(new_inp.size() == 2,
                           "deconv (conv bwd data) operator for inference can "
                           "only have 2 input vars(got:%zu)",
                           new_inp.size());
                mgb_assert(
                        opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
                auto inps = new_inp;
                // input(0) is the filter, input(1) the gradient/activation
                size_t out_channels = opr->input(0)->shape()[0];
                size_t in_channels = opr->input(0)->shape()[1];
                size_t new_out_channels = new_inp[1]->shape()[1];
                // pad output channels
                if (padding_oprs.count(opr->input(1)->owner_opr())) {
                    size_t pad_channels = new_out_channels - out_channels;
                    inps[0] = pad_out_channels(new_inp[0], pad_channels);
                } else {
                    size_t pad_channels = 0;
                    if (out_channels % 4)
                        pad_channels = 4 - (out_channels % 4);
                    if (pad_channels > 0) {
                        inps[0] = pad_out_channels(new_inp[0], pad_channels);
                        inps[1] = pad_in_channels(new_inp[1], pad_channels);
                    }
                }
                out_channels = inps[0]->shape()[0];
                in_channels = inps[0]->shape()[1];
                // pad input channels
                size_t pad_channels = 0;
                if (in_channels % 4)
                    pad_channels = 4 - (in_channels % 4);
                if (pad_channels > 0) {
                    inps[0] = pad_in_channels(inps[0], pad_channels);
                    padding_oprs.insert(opr);
                }
                return serialization::copy_opr_shallow(*opr, inps,
                                                       opr->config());
            };
    // Format-aware oprs (pooling, warp perspective) can operate on padded
    // channels directly; simply propagate the padded flag.
    auto replace_format_aware_opr = [&padding_oprs](
                                            OperatorNodeBase* opr,
                                            const VarNodeArray& new_inp) {
        if (opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS8 &&
            opr->input(0)->dtype().enumv() != DTypeEnum::QuantizedS4 &&
            opr->input(0)->dtype().enumv() != DTypeEnum::Quantized4Asymm) {
            mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                       "operator(type:%s,name:%s) for data type(%s) cannot be "
                       "padded channel. extra info:"
                       "consumer(%s), producer(%s)",
                       opr->dyn_typeinfo()->name, opr->cname(),
                       opr->input(0)->dtype().name(), opr->cname(),
                       opr->input(0)->owner_opr()->cname());
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        }
        mgb_assert(opr->input().size() == new_inp.size());
        if (padding_oprs.count(opr->input(0)->owner_opr())) {
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    opr_replace_funcs[opr::PoolingForward::typeinfo()] =
            replace_format_aware_opr;
    opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
            replace_format_aware_opr;
    // Elemwise-like oprs tolerate padding only when every input is padded to
    // the same channel count; otherwise padded inputs are sliced back first.
    auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor](
                                             OperatorNodeBase* opr,
                                             const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        bool have_padding_inp = false;
        bool padding_all_inps = true;
        bool same_padding = true;
        size_t channels_after_padding = 0;
        size_t i = 0;
        for (auto&& cur_inp : opr->input()) {
            bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                if (!have_padding_inp)
                    have_padding_inp = true;
                if (channels_after_padding == 0) {
                    // first padded input fixes the reference channel count
                    channels_after_padding = new_inp[i]->shape()[1];
                } else {
                    same_padding =
                            channels_after_padding == new_inp[i]->shape()[1];
                }
            }
            if (padding_all_inps && (!padding_cur_inp || !same_padding))
                padding_all_inps = false;
            ++i;
        }
        if (have_padding_inp && !padding_all_inps) {
            // mixed padded/unpadded inputs: un-pad the padded ones so that
            // broadcasting/elemwise shapes agree again
            auto inps = new_inp;
            for (size_t i = 0; i < new_inp.size(); ++i) {
                auto cur_inp = opr->input(i);
                bool padding_cur_inp =
                        padding_oprs.count(cur_inp->owner_opr()) > 0;
                if (padding_cur_inp) {
                    inps[i] = extract_subtensor(inps[i], cur_inp->shape());
                }
            }
            return serialization::copy_opr_shallow(*opr, inps, opr->config());
        }
        if (padding_all_inps) {
            padding_oprs.insert(opr);
        }
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    };
    opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] =
            replace_elemwise_like_opr;
    opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
    opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
    // Oprs that cannot handle padding at all: always slice padded inputs back
    // to their original channel counts before copying the opr.
    auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor](
                                           OperatorNodeBase* opr,
                                           const VarNodeArray& new_inp) {
        mgb_assert(opr->input().size() == new_inp.size());
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                inps[i] = extract_subtensor(inps[i], cur_inp->shape());
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    };
    opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
    opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
    // Main dispatch: rewrite known opr types via the table; anything else is
    // copied through. Graph endpoints get sliced back to their original
    // shape so external consumers never see padded channels.
    auto on_opr = [&opt, &rewriter, &opr_replace_funcs,
                   &extract_subtensor](OperatorNodeBase* opr) {
        auto it = opr_replace_funcs.find(opr->dyn_typeinfo());
        if (it != opr_replace_funcs.end()) {
            VarNodeArray new_inp;
            new_inp.reserve(opr->input().size());
            for (auto&& inp : opr->input()) {
                new_inp.push_back(rewriter.get_var(inp));
            }
            auto new_opr = (it->second)(opr, new_inp);
            auto &&out0 = opr->output(), &&out1 = new_opr->output();
            mgb_assert(out0.size() == out1.size(),
                       "bad opr replace: src=%s{%s} dst=%s{%s}, "
                       "src.size=%zu "
                       "dst.size=%zu",
                       opr->cname(), opr->dyn_typeinfo()->name,
                       new_opr->cname(), new_opr->dyn_typeinfo()->name,
                       out0.size(), out1.size());
            for (size_t i = 0; i < out0.size(); ++i) {
                if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                    mgb_assert(!out1[i]->contain_flag(
                            VarNode::Flag::VOLATILE_CONTENT));
                    auto src = out0[i];
                    auto dst = out1[i];
                    if (opt.graph().endpoint_contain(src) &&
                        !src->shape().eq_shape(dst->shape())) {
                        dst = extract_subtensor(dst, src->shape());
                    }
                    rewriter.replace_var(src, dst, nullptr);
                }
            }
        } else {
            rewriter.auto_replace_outputs(opr);
        }
    };
    opt.graph().iter(on_opr);
    rewriter.apply_inplace();
    MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -11,7 +11,6 @@
*/
#include "megbrain/gopt/reformat_manager.h"
#include <numeric>
#include "megbrain/opr/tensor_manip.h"
using namespace mgb;
......@@ -65,6 +64,10 @@ NamedTensorShape tensor_formats_to_named_tensor_shape(TensorFormats format) {
return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}};
case TensorFormats::KRSCk8:
return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}};
case TensorFormats::KCRSc4:
return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
case TensorFormats::GKCRSc4:
return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}};
case TensorFormats::KCRS:
return {{"K"}, {"C"}, {"R"}, {"S"}};
case TensorFormats::GKCRS:
......@@ -130,70 +133,40 @@ bool ReformatManager::ReformatKey::Equal::operator()(
lhs.attribute == rhs.attribute;
}
ReformatManager::ReformatKey&
ReformatManager::ReformatKey::deduce_reformat_dtype_enum(const DType& dt) {
    // Format pairs for which a 4-bit dtype must be recorded on the key so the
    // reformat implementation can account for the sub-byte element layout.
    static const ThinHashSet<std::pair<TensorFormats, TensorFormats>>
            low_bit_sensitive_pairs = {
                    {TensorFormats::NCHW, TensorFormats::NCHWc64},
                    {TensorFormats::NCHWc64, TensorFormats::NCHW},
                    {TensorFormats::NCHW, TensorFormats::NHWC},
                    {TensorFormats::NHWC, TensorFormats::NCHW}};
    const bool is_4bit = dt.enumv() == DTypeEnum::QuantizedS4 ||
                         dt.enumv() == DTypeEnum::Quantized4Asymm;
    if (is_4bit &&
        low_bit_sensitive_pairs.count({input_format, output_format}) > 0) {
        input_dtype = dt.enumv();
        output_dtype = dt.enumv();
    }
    return *this;
}
// =================== ReformatManager ====================*/
#define FOREACH_FEATURE_TENSOR_FORMATS(cb) \
cb(NCHW) cb(NHWC) cb(NCHWc4) cb(NCHWc8) cb(NCHWc32) cb(NCHWc64) cb(CHWNc4) \
cb(NHCWc4)
#define FOREACH_WEIGHT_TENSOR_FORMATS(cb) \
cb(KRSCk4) cb(KRSCk4c4) cb(KCRSk4c4) cb(KCRSc4k4) cb(KCRSc8k8) cb(KRSCk8) \
cb(GKRSCk4) cb(GKRSCk4c4) cb(GKCRSc4k4) cb(GKCRSk4c4) \
cb(GKCRSc8k8) cb(C11RSc4) cb(C11RSc8)
ReformatManager::ReformatManager() {
static constexpr TensorFormats feature_tensor_formats[] = {
#define cb(_fmt) TensorFormats::_fmt,
FOREACH_FEATURE_TENSOR_FORMATS(cb)
#undef cb
};
static constexpr int nr_feature_tensor_formats =
sizeof(feature_tensor_formats) / sizeof(TensorFormats);
for (int i = 0; i < nr_feature_tensor_formats; ++i) {
for (int o = 0; o < nr_feature_tensor_formats; ++o) {
if (i == o)
continue;
NamedTensorShape input_shape = tensor_formats_to_named_tensor_shape(
feature_tensor_formats[i]);
NamedTensorShape output_shape =
tensor_formats_to_named_tensor_shape(
feature_tensor_formats[o]);
auto impl = std::get<0>(
ReformatEmitter{input_shape, output_shape}.emit());
m_cache.emplace(ReformatKey{feature_tensor_formats[i],
feature_tensor_formats[o]},
impl);
}
}
static constexpr TensorFormats default_weight_tensor_formats =
TensorFormats::KCRS;
static constexpr TensorFormats default_group_conv_weight_tensor_formats =
TensorFormats::GKCRS;
static constexpr TensorFormats default_chan_conv_weight_tensor_formats =
TensorFormats::C11RS;
static constexpr TensorFormats weight_tensor_formats[] = {
#define cb(_fmt) TensorFormats::_fmt,
FOREACH_WEIGHT_TENSOR_FORMATS(cb)
#undef cb
};
static constexpr int nr_weight_tensor_formats =
sizeof(weight_tensor_formats) / sizeof(TensorFormats);
using Name = megdnn::Dimension::Name;
for (int o = 0; o < nr_weight_tensor_formats; ++o) {
NamedTensorShape output_shape =
tensor_formats_to_named_tensor_shape(weight_tensor_formats[o]);
TensorFormats input_format;
if (output_shape[0].name() == Name::G) {
input_format = default_group_conv_weight_tensor_formats;
} else if (output_shape[0].name() == Name::C) {
input_format = default_chan_conv_weight_tensor_formats;
} else {
mgb_assert(output_shape[0].name() == Name::K);
input_format = default_weight_tensor_formats;
}
NamedTensorShape input_shape =
tensor_formats_to_named_tensor_shape(input_format);
auto impl =
std::get<0>(ReformatEmitter{input_shape, output_shape}.emit());
m_cache.emplace(ReformatKey{input_format, weight_tensor_formats[o]},
impl);
using Attribute = ReformatKey::Attribute;
{
auto i = TensorFormats::NCHWc4, o = TensorFormats::CHWNc4;
auto&& impl1 = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4)
.node();
};
m_cache.emplace(ReformatKey{i, o}, impl1);
auto&& impl2 = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4)
.node();
};
m_cache.emplace(ReformatKey{o, i}, impl2);
}
{
auto i = TensorFormats::NCHW, o = TensorFormats::NCHWc4;
......@@ -206,7 +179,7 @@ ReformatManager::ReformatManager() {
m_cache.emplace(ReformatKey{i, o, Attribute::IC_SMALL}, impl);
}
{
auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4k4;
auto i = TensorFormats::KCRS, o = TensorFormats::KCRSc4;
auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
......@@ -238,7 +211,7 @@ ReformatManager::ReformatManager() {
auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64)
megdnn::param::RelayoutFormat::Mode::NCHW64_NCHW)
.node();
};
m_cache.emplace(
......@@ -272,7 +245,7 @@ ReformatManager::ReformatManager() {
auto&& impl = [](const VarNodeArray& vars) {
return opr::RelayoutFormat::make(
vars[0],
megdnn::param::RelayoutFormat::Mode::NCHW_NHWC)
megdnn::param::RelayoutFormat::Mode::NHWC_NCHW)
.node();
};
m_cache.emplace(
......@@ -371,14 +344,23 @@ ReformatManager::ReformatManager() {
impl);
}
}
#undef FOREACH_FEATURE_TENSOR_FORMATS
#undef FOREACH_WEIGHT_TENSOR_FORMATS
const ReformatManager::ReformatImpl& ReformatManager::get(
ReformatManager::ReformatImpl ReformatManager::get(
const ReformatKey& key) const {
using Attribute = ReformatKey::Attribute;
MGB_TRY {
auto&& impl = m_cache.at(key);
return impl;
auto find = m_cache.find(key);
if (find != m_cache.end()) {
auto rst = find->second;
return rst;
}
mgb_assert(key.attribute == Attribute::DEFAULT);
auto&& i = key.input_format;
auto&& o = key.output_format;
auto ishp = tensor_formats_to_named_tensor_shape(i);
auto oshp = tensor_formats_to_named_tensor_shape(o);
auto builder = std::get<0>(ReformatEmitter{ishp, oshp}.emit());
return builder;
}
MGB_CATCH(std::exception & exc, {
mgb_log_error(
......@@ -390,10 +372,7 @@ const ReformatManager::ReformatImpl& ReformatManager::get(
}
const ReformatManager& ReformatManager::instance() {
static ReformatManager* inst = nullptr;
if (inst == nullptr) {
inst = new ReformatManager();
}
return *inst;
static ReformatManager inst;
return inst;
}
// vim: syntax=cpp.doxygen
......@@ -42,6 +42,10 @@
#include "midout.h"
#include "megbrain/gopt/reformat_manager.h"
#include "./global_layout_transform/utils.h"
MIDOUT_DECL(megbrain_tensor_reformat)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_tensor_reformat, midout_iv(MGB_HASH_STR(tag))) {
......@@ -51,6 +55,7 @@ MIDOUT_DECL(megbrain_tensor_reformat)
using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;
/* ================ TensorReformatPass =============== */
/*!
......@@ -67,99 +72,40 @@ using namespace gopt;
* representations before being translated to MegBrain oprs, so the
* oprs should not get involved in any actual computing.
*/
// clang-format off
MGB_DEFINE_OPR_CLASS(TensorReformatPass::RelayoutPlaceholder,
cg::SingleCNOperatorNodeBase) // {
cg::SingleCNOperatorNodeBase) // {
public:
//! relayout type of this opr
enum class LayoutType {
NCHW4_TO_NCHW32, //!< from nchw4 layout to nchw32 layout
NCHW32_TO_NCHW4, //!< from nchw32 layout to nchw4 layout
NCHW4_TO_CHWN4, //!< from nchw4 layout to chwn4 layout
CHWN4_TO_NCHW4, //!< from chwn4 layout to nchw4 layout
NCHW_TO_NCHW4, //!< from nchw layout to nchw4 layout
NCHW_TO_NCHW4_IC_SMALL_CONV, ///< from nchw layout to nchw4 whose
///< channel size less than 4
NCHW4_TO_NCHW, //!< from nchw4 layout to nchw layout
NCHW_TO_NCHW88, //!< from nchw layout to nchw88 layout
NCHW88_TO_NCHW, //!< from nchw88 layout to nchw layout
WEIGHT_NCHW_TO_NCHW4_DENSE, //!< weight from nchw layout to nchw4
//!< layout
WEIGHT_NCHW_TO_NCHW4_GROUP, //!< group weight from nchw layout to
//!< nchw4 layout
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV, //!< weight from nchw layout
//!< to nchw4 layout whose
//! channel size less than 4
WEIGHT_NCHW_TO_NCHW88_DENSE, //!< weight from nchw layout to nchw88
//!< layout
WEIGHT_NCHW_TO_NCHW88_GROUP, //!< group weight from nchw layout to
//!< nchw88 layout
WEIGHT_NCHW_TO_NCHW88_CHAN, //!< channel wise weight from nchw layout
//!< to nchw88 layout
//!< the weight layout of input is nchw output is nchw88, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {8, 3, 3, 2, 8}
WEIGHT_HYBIRD_NCHW_NCHW88,
WEIGHT_NCHW_TO_NCHW44_DENSE, //!< weight from nchw layout to nchw44
//!< layout
WEIGHT_NCHW_TO_NCHW44_GROUP, //!< group weight from nchw layout to
//!< nchw44 layout
WEIGHT_NCHW_TO_NCHW44_CHAN, //!< channel wise weight from nchw layout
//!< to nchw44 layout
//!< the weight layout of input is nchw output is nchw44, special for
//!< shape weight in nchw like {64, 2, 3, 3} to {16, 3, 3, 2, 4}
WEIGHT_HYBIRD_NCHW_NCHW44,
WEIGHT_NCHW_TO_NCHW44_DOT_DENSE, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout dense
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP, //!< weight from NCHW44 layout to
//!< NCHW44_DOT layout group
NCHW32_TO_NCHW, //! <from nchw32 layout to nchw layout
NCHW32_TO_NCHW64, //! <from nchw32 layout to nchw64 layout
NCHW64_TO_NCHW, //! <from nchw64 layout to nchw layout
NCHW64_TO_NCHW4, //! <from nchw64 layout to nchw4 layout
NCHW64_TO_NCHW32, //! <from nchw64 layout to nchw32 layout
NCHW_TO_NCHW64, //! <from nchw layout to nchw64 layout
NCHW_TO_NCHW32, //! <from nchw layout to nchw64 layout
NCHW4_TO_NCHW64, //! <from nchw4 layout to nchw64 layout
NCHW_TO_NHWC, //! <NHWC related layout transformation
NCHW4_TO_NHWC,
NCHW32_TO_NHWC,
NCHW64_TO_NHWC,
NHWC_TO_NCHW,
NHWC_TO_NCHW4,
NHWC_TO_NCHW32,
NHWC_TO_NCHW64,
};
RelayoutPlaceholder(VarNode* src_var, LayoutType layout_type);
/*!
* \param src_var the input var
* \param layout_type tensor layout transform type of this relayout
* placeholder as described in LayoutType
*/
static SymbolVar make(VarNode* src_var, LayoutType layout_type);
RelayoutPlaceholder(VarNode* src_var, const ReformatKey& key);
LayoutType layout_type() const {
return m_layout_type;
}
/*!
* \param src_var the input var
* \param layout_type tensor layout transform type of this relayout
* placeholder as described in LayoutType
*/
static SymbolVar make(VarNode* src_var, const ReformatKey& key);
const ReformatKey& key() const {
return m_key;
}
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const LayoutType m_layout_type;
}
;
void init_output_static_infer_desc() override;
void scn_do_execute() override;
void init_output_comp_node() override;
const ReformatKey m_key;
VarNode* m_output;
};
MGB_DYN_TYPE_OBJ_FINAL_IMPL(TensorReformatPass::RelayoutPlaceholder);
// clang-format on
TensorReformatPass::RelayoutPlaceholder::RelayoutPlaceholder(
VarNode* src_var, LayoutType layout_type)
VarNode* src_var, const ReformatKey& key)
: Super(src_var->owner_graph(), {}, "RelayoutPlaceholder", {src_var}),
m_layout_type{layout_type} {
m_key{key} {
add_input({src_var});
add_equivalence_component<ScalarHash<LayoutType>>(m_layout_type);
add_equivalence_component<PODHash<ReformatKey>>(&m_key);
m_output = ReformatManager::instance().get(m_key)({src_var});
add_output(None)->dtype(src_var->dtype());
}
......@@ -174,360 +120,13 @@ void TensorReformatPass::RelayoutPlaceholder::init_output_comp_node() {
void TensorReformatPass::RelayoutPlaceholder::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
DepVal deps;
for (auto i : input())
deps.push_back({i, DepType::SHAPE});
auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
TensorShape inp_shape = inp.val[0].shape();
dst = inp_shape;
if (layout_type() == RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4] * 8;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4] / 8;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst[0] = inp_shape[1];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[0];
dst[4] = inp_shape[4];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst[0] = inp_shape[3];
dst[1] = inp_shape[0];
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = inp_shape[4];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4 ||
layout_type() == RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW4_IC_SMALL_CONV) {
if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0,
"src shape %s", inp_shape.to_string().c_str());
} else {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW4_IC_SMALL_CONV);
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
}
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = (inp_shape[1] + 4 - 1) / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE ||
layout_type() ==
RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV) {
if (layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 4 == 0);
} else {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV);
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] < 4);
}
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = (inp_shape[1] + 4 - 1) / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW4_GROUP) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[2] % 4 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1];
dst[2] = inp_shape[2] / 4;
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW88) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 8 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 8;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 8);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW88_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0 &&
inp_shape[1] % 8 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0] / 8;
dst[1] = inp_shape[1] / 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 8;
dst[5] = 8;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW88_GROUP) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 8 == 0 &&
inp_shape[2] % 8 == 0);
dst.ndim = 7;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 8;
dst[2] = inp_shape[2] / 8;
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 8;
dst[6] = 8;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW88_CHAN) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
inp_shape[2] == 1 && inp_shape[0] % 8 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0] / 8;
dst[1] = inp_shape[1];
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 8;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 8 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0] / 8;
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
dst[4] = 8;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DENSE ||
layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DOT_DENSE) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0 &&
inp_shape[1] % 4 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0] / 4;
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
dst[5] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_GROUP ||
layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_DOT_GROUP) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 4 == 0 &&
inp_shape[2] % 4 == 0);
dst.ndim = 7;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2] / 4;
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
dst[6] = 4;
} else if (layout_type() == RelayoutPlaceholder::LayoutType::
WEIGHT_NCHW_TO_NCHW44_CHAN) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] == 1 &&
inp_shape[2] == 1 && inp_shape[0] % 4 == 0);
dst.ndim = 6;
dst[0] = inp_shape[0] / 4;
dst[1] = inp_shape[1];
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4];
dst[5] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[0] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0] / 4;
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 32;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW64) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 2 == 0 &&
inp_shape[4] == 32);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 2;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 64;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 64;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 16;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW32) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 2;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 64 == 0, "%s",
inp_shape.to_string().c_str());
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 64;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 64;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW32) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[1] % 32 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 32;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[1] % 16 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 16;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 64;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC) {
mgb_assert(inp_shape.ndim == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1] * 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1] * 32;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC) {
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 64);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[1] * 64;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW) {
mgb_assert(inp_shape.ndim == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3];
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 4 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3] / 4;
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = 4;
} else if (layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32) {
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 32 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3] / 32;
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = 32;
} else {
mgb_assert(layout_type() ==
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
mgb_assert(inp_shape.ndim == 4 && inp_shape[3] % 64 == 0);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[3] / 64;
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = 64;
}
return true;
};
mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(m_output));
}
SymbolVar TensorReformatPass::RelayoutPlaceholder::make(
VarNode* src_var, LayoutType layout_type) {
VarNode* src_var, const ReformatKey& key) {
return src_var->owner_graph()
->insert_opr(
std::make_unique<RelayoutPlaceholder>(src_var, layout_type))
->insert_opr(std::make_unique<RelayoutPlaceholder>(src_var, key))
->output(0);
}
......@@ -576,541 +175,13 @@ void TensorReformatPass::insert_pass(OptState& opt) const {
}
void TensorReformatPass::translate_pass(OptState& opt) const {
ThinHashMap<RelayoutPlaceholder::LayoutType,
thin_function<VarNode*(VarNode*)>>
reformat;
using LayoutType = RelayoutPlaceholder::LayoutType;
reformat[LayoutType::NCHW4_TO_CHWN4] = [](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
reformat[LayoutType::CHWN4_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
reformat[LayoutType::NCHW4_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW32_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW4_IC_SMALL_CONV] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto y = opr::RelayoutFormat::make(
x, megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
return y.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto y = opr::RelayoutFormat::make(
x, megdnn::param::RelayoutFormat::Mode::
NCHW_NCHW4_IC_SMALL_CONV_DENSE_WEIGHT);
return y.node();
};
reformat[LayoutType::NCHW_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
return y1.node();
};
reformat[LayoutType::NCHW4_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 4, sub(2), sub(3), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2) / 4, cv(4), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1), sub(2) / 4, sub(3), sub(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 2, 4, 5, 3});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW88] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 8, sub(2), sub(3), cv(8)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW88_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) * 8, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 8, cv(8), sub(1) / 8, cv(8), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 8, sub(1) / 8, sub(2), sub(3), cv(8), cv(8)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 8, cv(8), sub(2) / 8,
cv(8), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make({sub(0), sub(1) / 8, sub(2) / 8, sub(3),
sub(4), cv(8), cv(8)},
0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW88_CHAN] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 8, cv(8), sub(1), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 8, sub(1), sub(2), sub(3), sub(4), cv(8)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW88] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 8, cv(8), sub(1), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 8, sub(2), sub(3), sub(1), cv(8)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 3, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
cv(4), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
sub(4), cv(4), cv(4)},
0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 4, 2});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_CHAN] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(1), sub(2), sub(3), sub(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 3, 4, 5, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_HYBIRD_NCHW_NCHW44] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(2), sub(3), sub(1), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 4, 2, 1});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0) / 4, cv(4), sub(1) / 4, cv(4), sub(2), sub(3)}, 0),
tshp1 = opr::Concat::make(
{sub(0) / 4, sub(1) / 4, sub(2), sub(3), cv(4), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 2, 4, 5, 1, 3});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(1) / 4, cv(4), sub(2) / 4,
cv(4), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make({sub(0), sub(1) / 4, sub(2) / 4, sub(3),
sub(4), cv(4), cv(4)},
0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 5, 6, 2, 4});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW32_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 =
opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NCHW32_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 2, cv(2), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 2, sub(2), sub(3), sub(4) * 2}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW64_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 =
opr::Concat::make({sub(0), sub(1) * 64, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NCHW64_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3), sub(4) / 4, cv(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) * 16, sub(2), sub(3), cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW64_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3), sub(4) / 32, cv(32)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) * 2, sub(2), sub(3), cv(32)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NCHW64;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
reformat[LayoutType::NCHW_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
return y1.node();
};
reformat[LayoutType::NCHW4_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 16, cv(16), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 16, sub(2), sub(3), sub(4) * 16}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
reformat[LayoutType::NCHW_TO_NHWC] = [](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW_NHWC;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
reformat[LayoutType::NCHW4_TO_NHWC] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NCHW32_TO_NHWC] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 =
opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 32}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NCHW64_TO_NHWC] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 =
opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 64}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp0);
return y1.node();
};
reformat[LayoutType::NHWC_TO_NCHW] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto y = opr::Dimshuffle::make(x, {0, 3, 1, 2});
return y.node();
};
reformat[LayoutType::NHWC_TO_NCHW4] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 4, cv(4)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
reformat[LayoutType::NHWC_TO_NCHW32] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 32, cv(32)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
reformat[LayoutType::NHWC_TO_NCHW64] = [](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3) / 64, cv(64)}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 3, 1, 2, 4});
return y1.node();
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&reformat, &rewriter](OperatorNodeBase* opr) {
auto on_opr = [&rewriter](OperatorNodeBase* opr) {
if (opr->same_type<RelayoutPlaceholder>()) {
auto ph = try_cast_as_op<RelayoutPlaceholder>(opr);
auto new_inp = rewriter.get_var(opr->input(0));
mgb_assert(reformat.count(ph->layout_type()),
"no replace rule can be found for layout_type(%u)",
static_cast<uint32_t>(ph->layout_type()));
auto new_var = reformat[ph->layout_type()](new_inp);
auto new_var =
ReformatManager::instance().get(ph->key())({new_inp});
rewriter.replace_var(opr->output(0), new_var,
mgb_cstr_log("replace relayout placeholder"));
return;
......@@ -1132,9 +203,9 @@ void TensorReformatPass::apply(OptState& opt) const {
VarNode* EnableTensorCorePass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
return RelayoutPlaceholder::make(
new_var,
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4)
return RelayoutPlaceholder::make(new_var,
ReformatKey{TensorFormats::NCHWc32,
TensorFormats::NCHWc4})
.node();
}
return new_var;
......@@ -1204,8 +275,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
if (group == 1 && ocpg % 32 == 0 && icpg % 32 == 0 && ih >= 3 &&
iw >= 3) {
auto symvar = RelayoutPlaceholder::make(
new_inp[0],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
new_inp[0], ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHWc32});
src = symvar.node();
can_replace_nchw32 = true;
} else {
......@@ -1232,8 +303,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
src = new_inp[0];
} else {
auto symvar = RelayoutPlaceholder::make(
new_inp[0],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
new_inp[0], ReformatKey{TensorFormats::NCHWc32,
TensorFormats::NCHWc4});
src = symvar.node();
}
}
......@@ -1241,7 +312,7 @@ EnableTensorCorePass::make_tensorcore_converter() {
if (can_replace_nchw32) {
auto symvar = RelayoutPlaceholder::make(
new_inp[1],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHWc32});
weight = symvar.node();
} else {
weight = new_inp[1];
......@@ -1265,8 +336,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
if (can_replace_nchw32) {
if (is_nchw4(inp->shape())) {
auto symvar = RelayoutPlaceholder::make(
inp,
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
inp, ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHWc32});
return symvar.node();
} else {
mgb_assert(is_nchw32(inp->shape()));
......@@ -1278,8 +349,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
} else {
mgb_assert(is_nchw32(inp->shape()));
auto symvar = RelayoutPlaceholder::make(
inp,
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
inp, ReformatKey{TensorFormats::NCHWc32,
TensorFormats::NCHWc4});
return symvar.node();
}
}
......@@ -1335,8 +406,9 @@ EnableTensorCorePass::make_tensorcore_converter() {
for (size_t i = 0; i < nr_inps; ++i) {
if (opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW32);
new_inp[i],
ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHWc32});
inps[i] = symvar.node();
}
}
......@@ -1344,8 +416,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
for (size_t i = 0; i < nr_inps; ++i) {
if (!opr->input(i)->shape().eq_shape(new_inp[i]->shape())) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW4);
new_inp[i], ReformatKey{TensorFormats::NCHWc32,
TensorFormats::NCHWc4});
inps[i] = symvar.node();
}
}
......@@ -1366,8 +438,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
mgb_assert(new_inp[i]->shape().ndim == 5 &&
new_inp[i]->shape()[4] == 32);
auto symvar = RelayoutPlaceholder::make(
new_inp[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
new_inp[i], ReformatKey{TensorFormats::NCHWc32,
TensorFormats::NCHWc4});
inps[i] = symvar.node();
}
}
......@@ -1447,8 +519,8 @@ EnableTensorCorePass::make_tensorcore_converter() {
if (opr->input(0)->shape().eq_shape(new_inp[0]->shape())) {
new_inp_var =
RelayoutPlaceholder::make(
new_inp[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW32)
new_inp[0], ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHWc32})
.node();
} else {
mgb_assert(opr->input(0)->shape().ndim == 5 &&
......@@ -1497,8 +569,9 @@ EnableTensorCorePass::make_tensorcore_converter() {
VarNode* EnableCHWN4Pass::on_graph_endpoint_var(VarNode* new_var,
VarNode* /* orig_var */) const {
if (m_varshape_changed.count(new_var)) {
return RelayoutPlaceholder::make(
new_var, RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4)
return RelayoutPlaceholder::make(new_var,
ReformatKey{TensorFormats::CHWNc4,
TensorFormats::NCHWc4})
.node();
}
return new_var;
......@@ -1547,7 +620,7 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
// currently not support group conv
auto symvar = RelayoutPlaceholder::make(
new_inp[0],
RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
ReformatKey{TensorFormats::NCHWc4, TensorFormats::CHWNc4});
src = symvar.node();
} else { // new input is NCHW32 layout
src = new_inp[0];
......@@ -1556,7 +629,7 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
{
auto symvar = RelayoutPlaceholder::make(
new_inp[1],
RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
ReformatKey{TensorFormats::NCHWc4, TensorFormats::CHWNc4});
weight = symvar.node();
}
if (new_inp.size() == 2) {
......@@ -1571,7 +644,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
auto process_inp = [&](VarNode* inp) -> VarNode* {
if (varshape_changed.count(inp) == 0) {
auto symvar = RelayoutPlaceholder::make(
inp, RelayoutPlaceholder::LayoutType::NCHW4_TO_CHWN4);
inp, ReformatKey{TensorFormats::NCHWc4,
TensorFormats::CHWNc4});
return symvar.node();
} else {
return inp;
......@@ -1618,8 +692,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
for (size_t i = 0; i < nr_inps; ++i) {
if (varshape_changed.count(new_inp[i]) == 0) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i], RelayoutPlaceholder::LayoutType::
NCHW4_TO_CHWN4);
new_inp[i], ReformatKey{TensorFormats::NCHWc4,
TensorFormats::CHWNc4});
inps[i] = symvar.node();
}
}
......@@ -1631,8 +705,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
for (size_t i = 0; i < nr_inps; ++i) {
if (varshape_changed.count(new_inp[i])) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i], RelayoutPlaceholder::LayoutType::
CHWN4_TO_NCHW4);
new_inp[i], ReformatKey{TensorFormats::CHWNc4,
TensorFormats::NCHWc4});
inps[i] = symvar.node();
}
}
......@@ -1651,8 +725,8 @@ std::unique_ptr<EnableCHWN4Pass> EnableCHWN4Pass::make_chwn4_converter() {
for (size_t i = 0; i < opr->input().size(); ++i) {
if (varshape_changed.count(new_inp[i])) {
auto symvar = RelayoutPlaceholder::make(
new_inp[i],
RelayoutPlaceholder::LayoutType::CHWN4_TO_NCHW4);
new_inp[i], ReformatKey{TensorFormats::CHWNc4,
TensorFormats::NCHWc4});
inps[i] = symvar.node();
}
}
......@@ -1770,7 +844,8 @@ VarNode* EnableNCHW4Pass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
return RelayoutPlaceholder::make(
new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
new_var,
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW})
.node();
}
return new_var;
......@@ -1781,7 +856,6 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
MIDOUT_B("EnableNCHW4Pass::make")
auto ret = std::make_unique<EnableNCHW4Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::NOCHECK);
using RelayoutMode = RelayoutPlaceholder::LayoutType;
megdnn::param::Convolution::Format conv_format =
megdnn::param::Convolution::Format::NCHW4;
megdnn::param::ConvBias::Format conv_bias_format =
......@@ -1790,16 +864,16 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
megdnn::param::ConvBias::Format::NCHW4_NCHW;
megdnn::param::BatchConvBias::Format batch_conv_bias_format =
megdnn::param::BatchConvBias::Format::NCHW4;
RelayoutMode src_to_nchw4_mode = RelayoutMode::NCHW_TO_NCHW4;
RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
RelayoutMode weight_to_nchw4_mode_dense =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE;
RelayoutMode weight_to_nchw4_mode_group =
RelayoutMode::WEIGHT_NCHW_TO_NCHW4_GROUP;
ReformatKey src_to_nchw4_mode{TensorFormats::NCHW, TensorFormats::NCHWc4};
ReformatKey src_to_nchw_mode{TensorFormats::NCHWc4, TensorFormats::NCHW};
ReformatKey weight_to_nchw4_mode_dense{TensorFormats::KCRS,
TensorFormats::KCRSc4};
ReformatKey weight_to_nchw4_mode_group{TensorFormats::GKCRS,
TensorFormats::GKCRSc4};
struct ConvMode {
RelayoutMode weight;
RelayoutMode src;
ReformatKey weight;
ReformatKey src;
};
auto trans_nchw4 =
......@@ -1812,8 +886,11 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
"The origin filter is not NCHW mode");
size_t IC = filter->shape()[1];
if (IC < 4) {
return {RelayoutMode::WEIGHT_NCHW_TO_NCHW4_DENSE_IC_SMALL_CONV,
RelayoutMode::NCHW_TO_NCHW4_IC_SMALL_CONV};
ReformatKey weight{TensorFormats::KCRS, TensorFormats::KCRSc4,
ReformatKey::Attribute::IC_SMALL};
ReformatKey src{TensorFormats::NCHW, TensorFormats::NCHWc4,
ReformatKey::Attribute::IC_SMALL};
return {weight, src};
} else {
return {weight_to_nchw4_mode_dense, src_to_nchw4_mode};
}
......@@ -1869,8 +946,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
};
auto replace_deconv_opr = [trans_nchw4, conv_format](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
if (new_inp[1]->dtype().enumv() == DTypeEnum::Float32) {
return serialization::copy_opr_shallow(*opr, new_inp,
opr->config());
......@@ -1885,8 +962,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
opr->config());
}
VarNode *deconv_src = new_inp[1], *deconv_filter = new_inp[0];
auto deconv_mode =
trans_nchw4(deconv_opr.param().sparse, deconv_filter);
auto deconv_mode = trans_nchw4(deconv_opr.param().sparse, deconv_filter);
// src: NCHW --> NCWH4
if (deconv_src->shape().ndim != 5) {
mgb_assert(deconv_src->shape().ndim == 4);
......@@ -1940,8 +1016,7 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
}
// weight: BNCHW --> BNCHW4
// only support dense mode, which is similar with conv->group.
auto weight_mode =
RelayoutPlaceholder::LayoutType::WEIGHT_NCHW_TO_NCHW4_GROUP;
ReformatKey weight_mode{TensorFormats::GKCRS, TensorFormats::GKCRSc4};
auto new_filter = RelayoutPlaceholder::make(new_inp[1], weight_mode);
filter = new_filter.node();
// format: NCHW --> NCHW4
......@@ -2033,10 +1108,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
conv_bias_src, conv_bias_filter, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.node()->dtype().enumv() ==
DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
}
// bias: NCHW --> NCHW4 when bias_dtype is not Float32
......@@ -2052,10 +1127,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
conv_bias_src, conv_bias_filter, conv_bias_bias, new_param,
conv_bias_opr.execution_policy(), conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.node()->dtype().enumv() ==
DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
}
// z_inp: NCHW --> NCHW4 when bias_dtype is not Float32
......@@ -2071,10 +1146,10 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
new_param, conv_bias_opr.execution_policy(),
conv_bias_opr.config());
OperatorNodeBase* new_opr = new_conv_bias_opr.node()->owner_opr();
mgb_assert(new_conv_bias_opr.node()->dtype().enumv() ==
DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
mgb_assert(
new_conv_bias_opr.node()->dtype().enumv() == DTypeEnum::Float32 ||
new_conv_bias_opr.shape().ndim == 5,
"The conv_bias dst dim is not trans to nchw4");
return new_opr;
};
auto replace_elemwise_opr = [=](OperatorNodeBase* opr,
......@@ -2215,7 +1290,8 @@ std::unique_ptr<EnableNCHW4Pass> EnableNCHW4Pass::make_nchw4_converter() {
auto&& replace_func = ret->m_opr_replace_func;
//! supportted nchw4
replace_func[opr::Convolution::typeinfo()] = replace_conv_opr;
replace_func[opr::ConvolutionBackwardData::typeinfo()] = replace_deconv_opr;
replace_func[opr::ConvolutionBackwardData::typeinfo()] =
replace_deconv_opr;
replace_func[opr::ConvBias::typeinfo()] = replace_conv_bias_opr;
replace_func[opr::BatchConvBias::typeinfo()] = replace_batch_conv_bias_opr;
replace_func[opr::PoolingForward::typeinfo()] = replace_pooling_opr;
......@@ -2244,14 +1320,14 @@ VarNode* EnableNchwxxPass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
if (m_pack_c_size == 8) {
return RelayoutPlaceholder::make(
new_var,
RelayoutPlaceholder::LayoutType::NCHW88_TO_NCHW)
return RelayoutPlaceholder::make(new_var,
ReformatKey{TensorFormats::NCHWc8,
TensorFormats::NCHW})
.node();
} else if (m_pack_c_size == 4) {
return RelayoutPlaceholder::make(
new_var,
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
return RelayoutPlaceholder::make(new_var,
ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHW})
.node();
}
}
......@@ -2335,17 +1411,16 @@ static inline bool nchw_nchwxx_valid(
}
void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
using RelayoutMode = RelayoutPlaceholder::LayoutType;
using TestFilterResult = std::pair<TransType, RelayoutMode>;
RelayoutMode weight_to_nchwxx_mode_dense =
RelayoutMode::WEIGHT_NCHW_TO_NCHW88_DENSE;
RelayoutMode weight_to_nchwxx_mode_group =
RelayoutMode::WEIGHT_NCHW_TO_NCHW88_GROUP;
RelayoutMode weight_to_nchwxx_mode_chan =
RelayoutMode::WEIGHT_NCHW_TO_NCHW88_CHAN;
RelayoutMode hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW88;
RelayoutMode src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW88;
RelayoutMode src_to_nchw_mode = RelayoutMode::NCHW88_TO_NCHW;
using TestFilterResult = std::pair<TransType, ReformatKey>;
ReformatKey weight_to_nchwxx_mode_dense{TensorFormats::KCRS,
TensorFormats::KCRSc8k8};
ReformatKey weight_to_nchwxx_mode_group{TensorFormats::GKCRS,
TensorFormats::GKCRSc8k8};
ReformatKey weight_to_nchwxx_mode_chan{TensorFormats::C11RS,
TensorFormats::C11RSc8};
ReformatKey hybrid_nchw_nchwxx{TensorFormats::KCRS, TensorFormats::KRSCk8};
ReformatKey src_to_nchwxx_mode{TensorFormats::NCHW, TensorFormats::NCHWc8};
ReformatKey src_to_nchw_mode{TensorFormats::NCHWc8, TensorFormats::NCHW};
megdnn::param::ConvBias::Format conv_bias_format =
megdnn::param::ConvBias::Format::NCHW88;
megdnn::param::Convolution::Format conv_format =
......@@ -2357,12 +1432,12 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
std::string convter_pass_name = "conv_format_nchw88";
if (pack_c_size == 4) {
weight_to_nchwxx_mode_dense = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
weight_to_nchwxx_mode_group = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
weight_to_nchwxx_mode_chan = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
hybrid_nchw_nchwxx = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
src_to_nchwxx_mode = RelayoutMode::NCHW_TO_NCHW4;
src_to_nchw_mode = RelayoutMode::NCHW4_TO_NCHW;
weight_to_nchwxx_mode_dense.output_format = TensorFormats::KCRSc4k4;
weight_to_nchwxx_mode_group.output_format = TensorFormats::GKCRSc4k4;
weight_to_nchwxx_mode_chan.output_format = TensorFormats::C11RSc4;
hybrid_nchw_nchwxx.output_format = TensorFormats::KRSCk4;
src_to_nchwxx_mode.output_format = TensorFormats::NCHWc4;
src_to_nchw_mode.input_format = TensorFormats::NCHWc4;
conv_bias_format = megdnn::param::ConvBias::Format::NCHW44;
conv_format = megdnn::param::Convolution::Format::NCHW44;
pooling_format = megdnn::param::Pooling::Format::NCHW44;
......@@ -2641,7 +1716,7 @@ void EnableNchwxxPass::fill_opr_convert_fun(size_t pack_c_size) {
return new_opr;
}
};
auto replace_resize_opr = [=](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
......@@ -2793,7 +1868,8 @@ VarNode* EnableNchw44DotPass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
return RelayoutPlaceholder::make(
new_var, RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW)
new_var,
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW})
.node();
}
return new_var;
......@@ -2807,10 +1883,9 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
//! First is whether the conv can trans to nchwxx, second is the filter
//! trans mode
using RelayoutMode = RelayoutPlaceholder::LayoutType;
struct TestTransResult {
TransType trans_type;
RelayoutMode relayout_mod;
ReformatKey relayout_mod;
megdnn::param::Convolution::Format conv_format;
};
constexpr size_t pack_c_size = 4_z;
......@@ -2828,18 +1903,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
if ((IC % pack_c_size == 0) && (OC % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
if (is_int8) {
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_DENSE;
ret.relayout_mod = ReformatKey{TensorFormats::KCRS,
TensorFormats::KCRSk4c4};
ret.conv_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
} else {
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DENSE;
ret.relayout_mod = ReformatKey{TensorFormats::KCRS,
TensorFormats::KCRSc4k4};
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
}
} else if (valid_nchw_nchw44) {
ret.trans_type = TransType::TRANS_HYBIRD_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_HYBIRD_NCHW_NCHW44;
ret.relayout_mod =
ReformatKey{TensorFormats::KCRS, TensorFormats::KRSCk4};
if (is_int8) {
ret.conv_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
......@@ -2855,18 +1931,19 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
size_t icpg = filter->shape()[2];
if (icpg == 1 && ocpg == 1 && (group % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
ret.relayout_mod = RelayoutMode::WEIGHT_NCHW_TO_NCHW44_CHAN;
ret.relayout_mod = ReformatKey{TensorFormats::C11RS,
TensorFormats::C11RSc4};
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
} else if ((icpg % pack_c_size == 0) && (ocpg % pack_c_size == 0)) {
ret.trans_type = TransType::TRANS_PURE_NCHWXX;
if (is_int8) {
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_DOT_GROUP;
ret.relayout_mod = ReformatKey{TensorFormats::GKCRS,
TensorFormats::GKCRSk4c4};
ret.conv_format =
megdnn::param::ConvBias::Format::NCHW44_DOT;
} else {
ret.relayout_mod =
RelayoutMode::WEIGHT_NCHW_TO_NCHW44_GROUP;
ret.relayout_mod = ReformatKey{TensorFormats::GKCRS,
TensorFormats::GKCRSc4k4};
ret.conv_format = megdnn::param::ConvBias::Format::NCHW44;
}
}
......@@ -2898,7 +1975,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
//! if src is nchwxx, should RelayoutPlaceholder to nchw
if (temp_inp[0]->shape().ndim == 5) {
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
new_inp[0], ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHW});
temp_inp[0] = new_src.node();
}
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
......@@ -2917,7 +1995,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
new_inp[0], ReformatKey{TensorFormats::NCHW,
TensorFormats::NCHWc4});
conv_src = new_src.node();
}
auto new_param = conv_opr.param();
......@@ -2986,14 +2065,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
//! if src is nchwxx, should RelayoutPlaceholder to nchw
if (temp_inp[0]->shape().ndim == 5) {
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW4_TO_NCHW);
new_inp[0],
ReformatKey{TensorFormats::NCHWc4, TensorFormats::NCHW});
temp_inp[0] = new_src.node();
}
//! the bias is nchwxx
if (new_inp.size() > 2 && temp_inp[2]->shape().ndim == 5) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW4_TO_NCHW);
new_inp[2], ReformatKey{TensorFormats::NCHWc4,
TensorFormats::NCHW});
temp_inp[2] = new_bias.node();
}
auto new_opr = serialization::copy_opr_shallow(*opr, temp_inp,
......@@ -3013,14 +2094,16 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
if (new_inp[0]->shape().ndim != 5) {
mgb_assert(new_inp[0]->shape().ndim == 4);
auto new_src = RelayoutPlaceholder::make(
new_inp[0], RelayoutMode::NCHW_TO_NCHW4);
new_inp[0], ReformatKey{TensorFormats::NCHW,
TensorFormats::NCHWc4});
conv_bias_src = new_src.node();
}
//! bias trans to nchwxx mode
if (new_inp.size() > 2) {
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
new_inp[2], ReformatKey{TensorFormats::NCHW,
TensorFormats::NCHWc4});
conv_bias_bias = new_bias.node();
} else {
mgb_assert(new_inp[2]->shape().ndim == 5);
......@@ -3059,7 +2142,8 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
if (new_inp.size() > 2) {
if (new_inp[2]->shape().ndim == 4) {
auto new_bias = RelayoutPlaceholder::make(
new_inp[2], RelayoutMode::NCHW_TO_NCHW4);
new_inp[2], ReformatKey{TensorFormats::NCHW,
TensorFormats::NCHWc4});
conv_bias_bias = new_bias.node();
} else {
mgb_assert(new_inp[2]->shape().ndim == 5);
......@@ -3099,303 +2183,21 @@ EnableNchw44DotPass::make_nchw44_dot_converter() {
/* ==================== ShuffleShuffleRemovePass ================= */
class ShuffleShuffleRemovePass::Impl {
using TensorFormat = opr::ConvBias::Param::Format;
using Format = opr::ConvBias::Param::Format;
OptState& m_opt_state;
ThinHashMap<std::pair<TensorFormat, TensorFormat>,
thin_function<VarNode*(VarNode*)>>
m_reformat;
class AbstractShuffleOpr;
using AbstractShuffleOpr = TensorReformatPass::RelayoutPlaceholder;
void detect_shuffle_operations();
void do_replace();
public:
Impl(OptState& opt_state) : m_opt_state{opt_state} {
m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW4)] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
return y1.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::NCHW32)] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make(
{sub(0), sub(1) / 32, cv(32), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2});
return y1.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW)] =
[](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp =
opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp);
return y1.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW)] =
[](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp =
opr::Concat::make({sub(0), sub(1) * 32, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp);
return y1.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::NCHW32)] =
[](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)},
0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW32, TensorFormat::NCHW4)] =
[](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8},
0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW4, TensorFormat::CHWN4)] =
[](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::NCHW4_CHWN4;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW4)] =
[](VarNode* inp) -> VarNode* {
megdnn::param::RelayoutFormat param;
param.mode = megdnn::param::RelayoutFormat::Mode::CHWN4_NCHW4;
auto reformat = opr::RelayoutFormat::make(inp, param);
return reformat.node();
};
m_reformat[std::make_pair(TensorFormat::NCHW, TensorFormat::CHWN4)] =
[](VarNode* inp) -> VarNode* {
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make(
{sub(0), sub(1) / 4, cv(4), sub(2), sub(3)}, 0);
auto y0 = opr::Reshape::make(x, tshp);
auto y1 = opr::Dimshuffle::make(y0, {1, 3, 4, 0, 2});
return y1.node();
};
m_reformat[std::make_pair(TensorFormat::CHWN4, TensorFormat::NCHW)] =
[](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp =
opr::Concat::make({sub(3), sub(0) * 4, sub(1), sub(2)}, 0);
auto y0 = opr::Dimshuffle::make(x, {3, 0, 4, 1, 2});
auto y1 = opr::Reshape::make(y0, tshp);
return y1.node();
};
detect_shuffle_operations();
do_replace();
}
};
/*!
 * \brief abstract operator representation of shuffle operation
 *
 * Compile-time placeholder inserted by ShuffleShuffleRemovePass to stand in
 * for a layout-conversion (dimshuffle/reshape) chain.  It only records the
 * source and destination tensor formats; it is never actually executed
 * (scn_do_execute throws InternalError) and must be replaced by a concrete
 * reformat before the graph runs.
 */
MGB_DEFINE_OPR_CLASS(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr,
cg::SingleCNOperatorNodeBase) // {
public:
AbstractShuffleOpr(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
static SymbolVar make(VarNode* inpvar, TensorFormat inp_format,
TensorFormat out_format);
//! tensor format of the input var consumed by this shuffle
TensorFormat inp_format() const {
return m_inp_format;
}
//! tensor format this shuffle converts the input to
TensorFormat out_format() const {
return m_out_format;
}
private:
void init_output_static_infer_desc() override;
void scn_do_execute() override;
const TensorFormat m_inp_format;
const TensorFormat m_out_format;
}
;
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr);
// AbstractShuffleOpr is a compile-time placeholder only; it must be replaced
// by a real reformat before execution, so reaching this point is a bug
void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::scn_do_execute() {
mgb_throw(InternalError, "AbstractShuffleOpr cannot be executed");
}
/*!
 * \brief static shape inference for AbstractShuffleOpr
 *
 * Registers a SHAPE-dependency rule that derives the output shape from the
 * input shape according to the (m_inp_format, m_out_format) pair; any format
 * pair outside the ones handled below aborts with InternalError.
 */
void ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::
init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
DepVal deps;
// only the shapes (not values) of the inputs are needed
for (auto i : input())
deps.push_back({i, DepType::SHAPE});
auto infer_shape = [this](TensorShape& dst, const InpVal& inp) {
TensorShape inp_shape = inp.val[0].shape();
if (m_inp_format == TensorFormat::NCHW4 &&
m_out_format == TensorFormat::NCHW32) {
// (N, C/4, H, W, 4) -> (N, C/32, H, W, 32): fold 8 channel groups
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst = inp_shape;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4] * 8;
} else if (m_inp_format == TensorFormat::NCHW32 &&
m_out_format == TensorFormat::NCHW4) {
// (N, C/32, H, W, 32) -> (N, C/4, H, W, 4): split into 8 groups
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 32);
dst = inp_shape;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 8;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = inp_shape[4] / 8;
} else if (m_inp_format == TensorFormat::NCHW &&
m_out_format == TensorFormat::NCHW4) {
// (N, C, H, W) -> (N, C/4, H, W, 4)
mgb_assert(inp_shape.ndim == 4);
dst.ndim = 5;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] / 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
dst[4] = 4;
} else if (m_inp_format == TensorFormat::NCHW4 &&
m_out_format == TensorFormat::NCHW) {
// (N, C/4, H, W, 4) -> (N, C, H, W)
mgb_assert(inp_shape.ndim == 5 && inp_shape[4] == 4);
dst.ndim = 4;
dst[0] = inp_shape[0];
dst[1] = inp_shape[1] * 4;
dst[2] = inp_shape[2];
dst[3] = inp_shape[3];
} else if (m_inp_format == TensorFormat::NCHW4 &&
m_out_format == TensorFormat::CHWN4) {
// pure axis permutation N<->C; inner 4-channel dim unchanged
// NOTE(review): unlike the branches above, no ndim/extent assert
// here — presumably callers guarantee a valid NCHW4 shape; confirm
dst.ndim = 5;
dst[0] = inp_shape[1];
dst[1] = inp_shape[2];
dst[2] = inp_shape[3];
dst[3] = inp_shape[0];
dst[4] = inp_shape[4];
} else if (m_inp_format == TensorFormat::CHWN4 &&
m_out_format == TensorFormat::NCHW4) {
// inverse permutation of the CHWN4 case above
dst.ndim = 5;
dst[0] = inp_shape[3];
dst[1] = inp_shape[0];
dst[2] = inp_shape[1];
dst[3] = inp_shape[2];
dst[4] = inp_shape[4];
} else {
mgb_throw(InternalError,
"Unsupported input format and output format.");
}
return true;
};
mgr.register_shape_infer(output(0), {SourceType::DEP, deps, infer_shape});
}
/*!
 * \brief construct the placeholder opr on the input var's owner graph
 *
 * Registers both formats as equivalence components so that graph-level
 * deduplication only merges two shuffles when input and output formats match.
 */
ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::AbstractShuffleOpr(
VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format)
: Super(inpvar->owner_graph(), {}, "AbstractShuffleOpr", {inpvar}),
m_inp_format{inp_format},
m_out_format{out_format} {
add_input({inpvar});
add_equivalence_component<ScalarHash<TensorFormat>>(m_inp_format);
add_equivalence_component<ScalarHash<TensorFormat>>(m_out_format);
// output keeps the input dtype; its shape comes from static shape infer
add_output(None)->dtype(inpvar->dtype());
}
/*!
 * \brief factory: create an AbstractShuffleOpr on the input var's graph and
 * return its (only) output as a SymbolVar
 */
SymbolVar ShuffleShuffleRemovePass::Impl::AbstractShuffleOpr::make(
VarNode* inpvar, TensorFormat inp_format, TensorFormat out_format) {
// build the placeholder, then hand ownership to the owner graph; the
// graph may return an existing equivalent opr instead (deduplication)
auto placeholder = std::make_unique<AbstractShuffleOpr>(inpvar, inp_format,
out_format);
auto* inserted = inpvar->owner_graph()->insert_opr(std::move(placeholder));
return inserted->output(0);
}
void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
auto rewriter = m_opt_state.graph().make_rewriter();
auto uniq_reader_check = UniqReaderCheck{m_opt_state.graph()};
......@@ -3423,7 +2225,8 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
return false;
auto inp_var = rewriter.get_var(reshape->input(0));
auto abstract_shuffle = AbstractShuffleOpr::make(
inp_var, TensorFormat::NCHW, TensorFormat::NCHW4);
inp_var,
ReformatKey{TensorFormats::NCHW, TensorFormats::NCHWc4});
rewriter.replace_var(
opr->output(0), abstract_shuffle.node(),
mgb_cstr_log("replace reformat(nchw -> nchw4) to "
......@@ -3469,12 +2272,11 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
if (reshape2 == nullptr)
return false;
auto inp_var = rewriter.get_var(reshape2->input(0));
TensorFormat inp_format = is_nchw42nchw32 ? TensorFormat::NCHW4
: TensorFormat::NCHW32,
out_format = is_nchw42nchw32 ? TensorFormat::NCHW32
: TensorFormat::NCHW4;
auto abstract_shuffle =
AbstractShuffleOpr::make(inp_var, inp_format, out_format);
Format inp_format = is_nchw42nchw32 ? Format::NCHW4 : Format::NCHW32,
out_format = is_nchw42nchw32 ? Format::NCHW32 : Format::NCHW4;
auto abstract_shuffle = AbstractShuffleOpr::make(
inp_var, ReformatKey{opr_format_to_tensor_formats(inp_format),
opr_format_to_tensor_formats(out_format)});
std::string reformat_type =
is_nchw42nchw32 ? "nchw4 -> nchw32" : "nchw32 -> nchw4";
rewriter.replace_var(opr->output(0), abstract_shuffle.node(),
......@@ -3507,11 +2309,22 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
param.pattern[2] == 4 && param.pattern[3] == 2 &&
param.pattern[4] == 3 &&
shuffle->input(0)->shape()[4] == 4;
if (!is_nchw42nchw)
bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 &&
param.pattern[2] == 3 && param.pattern[3] == 1 &&
param.pattern[4] == 4 &&
shuffle->input(0)->shape()[4] == 4;
if (!is_nchw42nchw && !is_nchw42nhwc)
return false;
auto inp_var = rewriter.get_var(shuffle->input(0));
auto abstract_shuffle = AbstractShuffleOpr::make(
inp_var, TensorFormat::NCHW4, TensorFormat::NCHW);
ReformatKey key;
key.input_format = TensorFormats::NCHWc4;
if (is_nchw42nchw) {
key.output_format = TensorFormats::NCHW;
} else {
mgb_assert(is_nchw42nhwc);
key.output_format = TensorFormats::NHWC;
}
auto abstract_shuffle = AbstractShuffleOpr::make(inp_var, key);
rewriter.replace_var(
opr->output(0), abstract_shuffle.node(),
mgb_cstr_log("replace reformat(nchw4 -> nchw) to "
......@@ -3532,10 +2345,12 @@ void ShuffleShuffleRemovePass::Impl::detect_shuffle_operations() {
cg::SymbolVar abstract_shuffle;
if (param.mode == opr::RelayoutFormat::Param::Mode::NCHW4_CHWN4) {
abstract_shuffle = AbstractShuffleOpr::make(
inp_var, TensorFormat::NCHW4, TensorFormat::CHWN4);
inp_var,
ReformatKey{TensorFormats::NCHWc4, TensorFormats::CHWNc4});
} else {
abstract_shuffle = AbstractShuffleOpr::make(
inp_var, TensorFormat::CHWN4, TensorFormat::NCHW4);
inp_var,
ReformatKey{TensorFormats::CHWNc4, TensorFormats::NCHWc4});
}
rewriter.replace_var(
opr->output(0), abstract_shuffle.node(),
......@@ -3591,7 +2406,7 @@ void ShuffleShuffleRemovePass::Impl::do_replace() {
}
}
auto on_opr = [this, &rewriter, &uniq_reader_check, &trt_opr_inps,
auto on_opr = [&rewriter, &uniq_reader_check, &trt_opr_inps,
&root](OperatorNodeBase* opr) {
MGB_MARK_USED_VAR(trt_opr_inps);
bool cond_opr = opr->same_type<opr::TypeCvt>() ||
......@@ -3606,17 +2421,17 @@ void ShuffleShuffleRemovePass::Impl::do_replace() {
bool force_folding_typecvt = false;
bool first_shuffle = false;
// initialize inp_format and out_format
TensorFormat out_format = TensorFormat::NCHW,
inp_format = out_format;
TensorFormats out_format = TensorFormats::NCHW,
inp_format = out_format;
megdnn::DType inp_dtype = cur->input(0)->dtype(),
out_dtype = cur->output(0)->dtype();
SmallVector<megdnn::DType> out_dtype_vec;
while (cond_opr) {
if (cur->same_type<AbstractShuffleOpr>()) {
auto shuffle = try_cast_as_op<AbstractShuffleOpr>(cur);
inp_format = shuffle->inp_format();
inp_format = shuffle->key().input_format;
if (!first_shuffle) {
out_format = shuffle->out_format();
out_format = shuffle->key().output_format;
first_shuffle = true;
}
} else {
......@@ -3639,11 +2454,8 @@ void ShuffleShuffleRemovePass::Impl::do_replace() {
#endif
auto new_var = rewriter.get_var(inp_var);
if (inp_format != out_format) {
mgb_assert(m_reformat.find(std::make_pair(
inp_format, out_format)) != m_reformat.end(),
"Unsupported shuffle shuffle remove pass");
new_var = m_reformat[std::make_pair(inp_format, out_format)](
new_var);
new_var = ReformatManager::instance().get(
ReformatKey{inp_format, out_format})({new_var});
}
if (force_folding_typecvt) {
inp_dtype = inp_var->dtype();
......@@ -3683,992 +2495,111 @@ void ShuffleShuffleRemovePass::apply(OptState& opt) const {
MIDOUT_E
}
#if CUDA_VERSION >= 10020
/* ==================== FoldingConvBiasDimshufflePass ================= */
const char* FoldingConvBiasDimshufflePass::name() const {
return mgb_cstr_log("folding conv bias dimshuffle pass");
/* ================ EnableNCHW64Pass =============== */
/*!
 * \brief restore a graph-endpoint var to NCHW layout if this pass changed it
 *
 * When the replaced var's shape differs from the original's, the producing
 * opr's recorded format is looked up in m_opr_format_map and a
 * RelayoutPlaceholder is inserted converting that format back to NCHW with
 * the dtype unchanged; otherwise the var is returned as is.
 */
VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
VarNode* orig_var) const {
if (!orig_var->shape().eq_shape(new_var->shape())) {
auto iter = m_opr_format_map.find(new_var->owner_opr());
mgb_assert(iter != m_opr_format_map.end(),
"cannot find opr(type:%s,name:%s) information, related "
"output var node(name:%s)",
new_var->owner_opr()->dyn_typeinfo()->name,
new_var->owner_opr()->cname(), new_var->cname());
const auto& fmt = iter->second;
ReformatKey key;
// converting an unsupported format asserts inside the helper; log
// which var triggered it before rethrowing so the error is actionable
MGB_TRY {
key.input_format = opr_format_to_tensor_formats(fmt);
key.output_format = TensorFormats::NCHW;
key.input_dtype = new_var->dtype().enumv();
key.output_dtype = new_var->dtype().enumv();
}
MGB_CATCH(AssertionError & err, {
mgb_log_error("%s, related var node(name:%s)", err.what(),
orig_var->cname());
throw;
})
return RelayoutPlaceholder::make(new_var, key).node();
}
return new_var;
}
void FoldingConvBiasDimshufflePass::apply(OptState& opt) const {
MIDOUT_B("FoldingConvBiasDimshufflePass::apply");
using DepType = cg::OperatorNodeProp::DepType;
ThinHashMap<OperatorNodeBase*,
SmallVector<std::pair<OperatorNodeBase*, DepType>>>
readers;
static const ThinHashSet<Typeinfo*> opr_type_list = {
opr::TypeCvt::typeinfo(), opr::Dimshuffle::typeinfo(),
opr::Reshape::typeinfo(), opr::ConvBias::typeinfo()};
opt.graph().iter([&readers](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
readers[i.first->owner_opr()].emplace_back(opr, i.second);
}
std::unique_ptr<EnableNCHW64Pass>
EnableNCHW64Pass::make_nchw64_converter() {
MIDOUT_B("EnableNCHW64Pass::make")
auto ret = std::make_unique<EnableNCHW64Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
VarReplaceCheckFlag::CHECK_SHAPE);
auto& replace_func = ret->m_opr_replace_func;
auto& format_map = ret->m_opr_format_map;
auto make_new_conv = [](const VarNodeArray& inps,
const opr::ConvBiasForward* orig_conv,
Format format) {
auto param = orig_conv->param();
// change format
param.format = format;
if (inps.size() == 2) {
auto new_conv = opr::ConvBiasForward::make(
inps[0], inps[1], param, orig_conv->execution_policy(),
orig_conv->config());
return new_conv.node();
} else if (inps.size() == 3) {
auto new_conv = opr::ConvBiasForward::make(
inps[0], inps[1], inps[2], param,
orig_conv->execution_policy(), orig_conv->config());
return new_conv.node();
} else {
mgb_assert(inps.size() == 4);
auto new_conv = opr::ConvBiasForward::make(
inps[0], inps[1], inps[2], inps[3], param,
orig_conv->execution_policy(), orig_conv->config());
return new_conv.node();
}
});
auto rewriter = opt.graph().make_rewriter();
auto nchw42nchw = [](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make({sub(0), sub(1) * 4, sub(2), sub(3)}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 1, 4, 2, 3});
auto y1 = opr::Reshape::make(y0, tshp);
auto y2 = opr::TypeCvt::make(y1, dtype::Float32());
return y2.node();
};
auto nchw42nchw32 = [](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
auto try_transform_to_nchw =
[&format_map](OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size() == new_inp.size());
bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 3)
check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 4)
check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
if (!check_dtype)
return nullptr;
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
return inps[i];
} else {
const auto& fmt = iter->second;
ReformatKey key;
key.input_format = opr_format_to_tensor_formats(fmt);
key.output_format = TensorFormats::NCHW;
return RelayoutPlaceholder::make(inps[i], key).node();
}
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1) / 8, cv(8), sub(2), sub(3), sub(4)}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) / 8, sub(2), sub(3), sub(4) * 8}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 3, 4, 2, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
auto nchw322nchw4 = [](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 32);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp0 = opr::Concat::make(
{sub(0), sub(1), sub(2), sub(3), cv(8), sub(4) / 8}, 0),
tshp1 = opr::Concat::make(
{sub(0), sub(1) * 8, sub(2), sub(3), sub(4) / 8}, 0);
auto y0 = opr::Reshape::make(x, tshp0);
auto y1 = opr::Dimshuffle::make(y0, {0, 1, 4, 2, 3, 5});
auto y2 = opr::Reshape::make(y1, tshp1);
return y2.node();
};
auto nchw42nhwc = [](VarNode* inp) -> VarNode* {
mgb_assert(inp->shape().ndim == 5 && inp->shape()[4] == 4);
auto x = SymbolVar(inp);
auto xshp = opr::GetVarShape::make(x);
auto cv = [&x](int v) { return x.make_scalar(v); };
auto sub = [&xshp, &cv](int idx) {
return opr::IndexAt::make(xshp, {{0, cv(idx)}});
};
auto tshp = opr::Concat::make({sub(0), sub(2), sub(3), sub(1) * 4}, 0);
auto y0 = opr::Dimshuffle::make(x, {0, 2, 3, 1, 4});
auto y1 = opr::Reshape::make(y0, tshp);
return y1.node();
};
auto try_conv_dimshuffle_reshape_typecvt = [&rewriter, &readers,
&nchw42nchw](
OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
if (typecvt == nullptr)
return false;
auto inp_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
bool is_s82f32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype.enumv() == DTypeEnum::Float32;
if (!is_s82f32)
return false;
opr_set.insert(opr);
// check reshape
auto reshape =
try_cast_as_op<opr::Reshape>(typecvt->input(0)->owner_opr());
if (reshape == nullptr)
return false;
opr_set.insert(reshape);
for (auto&& i : readers[reshape]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check shuffle
auto shuffle =
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 5)
return false;
bool is_nchw42nchw = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 4 && param.pattern[3] == 2 &&
param.pattern[4] == 3 &&
shuffle->input(0)->shape()[4] == 4;
if (!is_nchw42nchw)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(shuffle->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = nchw42nchw(bias);
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{dtype::Float32()});
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + typecvt + "
"dimshuffle + "
"reshape to conv_bias(NCHW4_NCHW)"));
return true;
};
auto try_conv_reformat_nchw42nchw32 = [&rewriter, &nchw42nchw32,
&readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
if (reshape1 == nullptr)
return false;
opr_set.insert(opr);
// check dimshuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(
reshape1->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 6)
return false;
bool is_nchw42nchw32 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 3 && param.pattern[3] == 4 &&
param.pattern[4] == 2 && param.pattern[5] == 5 &&
shuffle->output(0)->shape()[5] == 4 &&
shuffle->output(0)->shape()[4] == 8;
if (!is_nchw42nchw32)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check reshape
auto reshape2 =
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
if (reshape2 == nullptr)
return false;
opr_set.insert(reshape2);
for (auto&& i : readers[reshape2]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = nchw42nchw32(bias);
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NCHW32;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
conv_bias->config());
rewriter.replace_var(
opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NCHW32)"));
return true;
};
auto try_conv_reformat_nchw42nhwc = [&rewriter, &nchw42nhwc,
&readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape = try_cast_as_op<opr::Reshape>(opr);
if (reshape == nullptr)
return false;
opr_set.insert(opr);
// check dimshuffle
auto shuffle =
try_cast_as_op<opr::Dimshuffle>(reshape->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 5)
return false;
bool is_nchw42nhwc = param.pattern[0] == 0 && param.pattern[1] == 2 &&
param.pattern[2] == 3 && param.pattern[3] == 1 &&
param.pattern[4] == 4 &&
shuffle->output(0)->shape()[4] == 4;
if (!is_nchw42nhwc)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
auto typecvt =
try_cast_as_op<opr::TypeCvt>(shuffle->input(0)->owner_opr());
if (typecvt == nullptr)
return false;
auto in_dtype = typecvt->input(0)->dtype(),
out_dtype = typecvt->output(0)->dtype();
bool is_s82s4 = in_dtype.enumv() == DTypeEnum::QuantizedS8 &&
(out_dtype.enumv() == DTypeEnum::QuantizedS4 ||
out_dtype.enumv() == DTypeEnum::Quantized4Asymm);
if (!is_s82s4)
return false;
opr_set.insert(typecvt);
for (auto&& i : readers[typecvt]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw4 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW4;
if (!is_s8nchw4)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = nchw42nhwc(bias);
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW4_NHWC;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype});
rewriter.replace_var(opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW4_NHWC)"));
return true;
};
auto try_conv_reformat_nchw322nchw4 = [&rewriter, &readers, &nchw322nchw4](
OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check reshape
auto reshape1 = try_cast_as_op<opr::Reshape>(opr);
if (reshape1 == nullptr)
return false;
opr_set.insert(opr);
// check dimshuffle
auto shuffle = try_cast_as_op<opr::Dimshuffle>(
reshape1->input(0)->owner_opr());
if (shuffle == nullptr)
return false;
auto&& param = shuffle->param();
if (param.pattern_len != 6)
return false;
bool is_nchw322nchw4 = param.pattern[0] == 0 && param.pattern[1] == 1 &&
param.pattern[2] == 4 && param.pattern[3] == 2 &&
param.pattern[4] == 3 && param.pattern[5] == 5 &&
shuffle->input(0)->shape()[5] == 4 &&
shuffle->input(0)->shape()[4] == 8;
if (!is_nchw322nchw4)
return false;
opr_set.insert(shuffle);
for (auto&& i : readers[shuffle]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check reshape
auto reshape2 =
try_cast_as_op<opr::Reshape>(shuffle->input(0)->owner_opr());
if (reshape2 == nullptr)
return false;
opr_set.insert(reshape2);
for (auto&& i : readers[reshape2]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(reshape2->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype = conv_bias->input(0)->dtype();
bool is_s8nchw32 = inp_dtype.enumv() == DTypeEnum::QuantizedS8 &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NCHW32;
if (!is_s8nchw32)
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
auto new_bias = nchw322nchw4(bias);
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NCHW32_NCHW4;
auto conv_bias_shuffle = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
conv_bias->config());
rewriter.replace_var(
opr->output(0), conv_bias_shuffle.node(),
mgb_cstr_log("replace conv_bias + "
"reformat to conv_bias(NCHW32_NCHW4)"));
return true;
};
MGB_MARK_USED_VAR(try_conv_reformat_nchw322nchw4);
MGB_MARK_USED_VAR(try_conv_reformat_nchw42nchw32);
auto on_opr = [&try_conv_dimshuffle_reshape_typecvt,
&try_conv_reformat_nchw42nchw32,
&try_conv_reformat_nchw42nhwc,
&try_conv_reformat_nchw322nchw4,
&rewriter](OperatorNodeBase* opr) {
if (!try_conv_dimshuffle_reshape_typecvt(opr) &&
!try_conv_reformat_nchw42nchw32(opr) &&
!try_conv_reformat_nchw42nhwc(opr) &&
!try_conv_reformat_nchw322nchw4(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
#endif
/* ==================== PaddingChannelPass ================= */
//! human-readable pass name used by the optimizer's logging
const char* PaddingChannelPass::name() const {
return mgb_cstr_log("padding output channel to multiple of 4/32");
}
void PaddingChannelPass::apply(OptState& opt) const {
MIDOUT_B("PaddingChannelPass::apply");
// do not check shape
opt.set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
VarReplaceCheckFlag::CHECK_SHAPE);
ThinHashSet<OperatorNodeBase*> padding_oprs;
ThinHashMap<Typeinfo*, thin_function<OperatorNodeBase*(
OperatorNodeBase*, const VarNodeArray&)>>
opr_replace_funcs;
auto rewriter = opt.graph().make_rewriter();
auto pad_in_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
inp->dtype().enumv() == DTypeEnum::QuantizedS32);
TensorShape shape{inp->shape()[0], pad_channels, inp->shape()[2],
inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes =
TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding =
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 1);
return out.node();
};
auto pad_out_channels = [](VarNode* inp, size_t pad_channels) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->dtype().enumv() == DTypeEnum::QuantizedS4 ||
inp->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
inp->dtype().enumv() == DTypeEnum::QuantizedS8 ||
inp->dtype().enumv() == DTypeEnum::QuantizedS32);
TensorShape shape{pad_channels, inp->shape()[1], inp->shape()[2],
inp->shape()[3]};
std::shared_ptr<HostTensorND> host_val =
std::make_shared<HostTensorND>(inp->comp_node(), inp->dtype());
host_val->resize(shape);
auto ptr = host_val->raw_ptr();
size_t size_bytes =
TensorLayout{shape, inp->dtype()}.span().dist_byte();
std::memset(ptr, 0, size_bytes);
auto padding =
opr::ImmutableTensor::make(*inp->owner_graph(), *host_val);
auto out = opr::Concat::make({inp, padding}, 0);
return out.node();
};
auto extract_subtensor = [](VarNode* inp,
const TensorShape& orig_shape) -> VarNode* {
mgb_assert(inp->shape().ndim == 4);
mgb_assert(inp->shape()[0] == orig_shape[0]);
mgb_assert(inp->shape()[2] == orig_shape[2]);
mgb_assert(inp->shape()[3] == orig_shape[3]);
size_t orig_channels = orig_shape[1];
auto x = SymbolVar(inp);
auto cv = [&x](int v) { return x.make_scalar(v); };
using AIdx = opr::Subtensor::AxisIndexer;
auto sub = opr::Subtensor::make(
x, {AIdx::make_interval(0, None, None, cv(1)),
AIdx::make_interval(1, None, cv(orig_channels), None),
AIdx::make_interval(2, None, None, cv(1)),
AIdx::make_interval(3, None, None, cv(1))});
return sub.node();
};
// padding policy for conv bias with data type qint8
auto padding_policy_qint8 = [&padding_oprs, &pad_in_channels,
&pad_out_channels](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(new_inp.size() == 3);
mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
auto inps = new_inp;
size_t out_channels = opr->input(1)->shape()[0];
size_t in_channels = opr->input(1)->shape()[1];
size_t new_in_channels = new_inp[0]->shape()[1];
// pad input channels
if (padding_oprs.count(opr->input(0)->owner_opr())) {
size_t pad_channels = new_in_channels - in_channels;
inps[1] = pad_in_channels(new_inp[1], pad_channels);
} else {
size_t pad_channels = 0;
mgb_assert(new_in_channels == in_channels);
if (in_channels <= 16) {
if (in_channels % 4)
pad_channels = 4 - (in_channels % 4); // pad to use dp4a
} else {
if (in_channels % 32)
pad_channels =
32 - (in_channels % 32); // pad to use tensorcore
}
if (pad_channels > 0) {
inps[0] = pad_in_channels(new_inp[0], pad_channels);
inps[1] = pad_in_channels(new_inp[1], pad_channels);
}
}
out_channels = inps[1]->shape()[0];
in_channels = inps[1]->shape()[1];
size_t pad_channels = 0;
if (out_channels <= 16) {
if (out_channels % 4)
pad_channels = 4 - (out_channels % 4);
} else {
if (out_channels % 32)
pad_channels = 32 - (out_channels % 32);
}
if (pad_channels > 0) {
inps[1] = pad_out_channels(inps[1], pad_channels);
inps[2] = pad_in_channels(inps[2], pad_channels);
padding_oprs.insert(opr);
}
return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
// padding policy for conv bias with data type qint4 and quint4
// Same idea as the qint8 policy, but int4 kernels need channel counts that
// are multiples of 8 (small channel count) or 64 (large channel count).
auto padding_policy_int4 = [&padding_oprs, &pad_in_channels,
                            &pad_out_channels](
                                   OperatorNodeBase* opr,
                                   const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    mgb_assert(new_inp.size() == 3);
    mgb_assert(opr->input(1)->shape().eq_shape(new_inp[1]->shape()));
    auto inps = new_inp;
    size_t out_channels = opr->input(1)->shape()[0];
    size_t in_channels = opr->input(1)->shape()[1];
    size_t new_in_channels = new_inp[0]->shape()[1];
    // pad input channels
    if (padding_oprs.count(opr->input(0)->owner_opr())) {
        // producer already padded the feature map; if its channel count is
        // already aligned, only pad the weight to match, otherwise pad both
        // the feature map and the weight up to the next multiple
        if (new_in_channels <= 32) {
            if (new_in_channels % 8 == 0) {
                size_t pad_channels = new_in_channels - in_channels;
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            } else {
                // NOTE(review): assumes in_channels % 8 != 0 here; if it
                // were aligned this would pad a full extra group of 8 —
                // TODO confirm
                size_t pad_channels_0 = 8 - (new_in_channels % 8);
                size_t pad_channels_1 = 8 - (in_channels % 8);
                inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
            }
        } else {
            if (new_in_channels % 64 == 0) {
                size_t pad_channels = new_in_channels - in_channels;
                inps[1] = pad_in_channels(new_inp[1], pad_channels);
            } else {
                size_t pad_channels_0 = 64 - (new_in_channels % 64);
                size_t pad_channels_1 = 64 - (in_channels % 64);
                inps[0] = pad_in_channels(new_inp[0], pad_channels_0);
                inps[1] = pad_in_channels(new_inp[1], pad_channels_1);
            }
        }
    } else {
        size_t pad_channels = 0;
        mgb_assert(new_in_channels == in_channels);
        if (in_channels <= 32) {
            if (in_channels % 8)
                pad_channels = 8 - (in_channels % 8);
        } else {
            if (in_channels % 64)
                pad_channels = 64 - (in_channels % 64);
        }
        if (pad_channels > 0) {
            inps[0] = pad_in_channels(new_inp[0], pad_channels);
            inps[1] = pad_in_channels(new_inp[1], pad_channels);
        }
    }
    // pad output channels (read back from the possibly-padded weight)
    out_channels = inps[1]->shape()[0];
    in_channels = inps[1]->shape()[1];
    size_t pad_channels = 0;
    if (out_channels <= 32) {
        if (out_channels % 8)
            pad_channels = 8 - (out_channels % 8);
    } else {
        if (out_channels % 64)
            pad_channels = 64 - (out_channels % 64);
    }
    if (pad_channels > 0) {
        // pad weight output channels and the bias accordingly
        inps[1] = pad_out_channels(inps[1], pad_channels);
        inps[2] = pad_in_channels(inps[2], pad_channels);
        // record: this opr's output var now has padded channels
        padding_oprs.insert(opr);
    }
    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
// ConvBias: dispatch to the padding policy matching the dtype of its
// feature-map input; any other dtype must not receive padded inputs
opr_replace_funcs[opr::ConvBiasForward::typeinfo()] =
        [&padding_oprs, &padding_policy_qint8, &padding_policy_int4](
                OperatorNodeBase* opr, const VarNodeArray& new_inp) {
            auto dtype = opr->input(0)->dtype().enumv();
            if (dtype == DTypeEnum::QuantizedS8)
                return padding_policy_qint8(opr, new_inp);
            if (dtype == DTypeEnum::QuantizedS4 ||
                dtype == DTypeEnum::Quantized4Asymm)
                return padding_policy_int4(opr, new_inp);
            mgb_assert(
                    padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                    "conv bias operator for data type(%s) cannot be "
                    "padded channel. "
                    "consumer(%s), producer(%s)",
                    opr->input(0)->dtype().name(), opr->cname(),
                    opr->input(0)->owner_opr()->cname());
            return serialization::copy_opr_shallow(*opr, new_inp,
                                                   opr->config());
        };
// deconv (ConvolutionBackwardData): pad filter/feature-map channels to
// multiples of 4 for the qint8 path; any other dtype must not receive
// padded inputs
opr_replace_funcs[opr::ConvolutionBackwardData::typeinfo()] =
        [&padding_oprs, &pad_in_channels, &pad_out_channels](
                OperatorNodeBase* opr, const VarNodeArray& new_inp) {
            if (opr->input(1)->dtype().enumv() != DTypeEnum::QuantizedS8) {
                mgb_assert(
                        padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                        "conv bwd data operator for data type(%s) cannot "
                        "be "
                        "padded channel. "
                        "consumer(%s), producer(%s)",
                        opr->input(0)->dtype().name(), opr->cname(),
                        opr->input(0)->owner_opr()->cname());
                return serialization::copy_opr_shallow(*opr, new_inp,
                                                       opr->config());
            }
            mgb_assert(opr->input().size() == new_inp.size());
            mgb_assert(new_inp.size() == 2,
                       "deconv (conv bwd data) operator for inference can "
                       "only have 2 input vars(got:%zu)",
                       new_inp.size());
            // the filter var must not have been rewritten earlier
            mgb_assert(
                    opr->input(0)->shape().eq_shape(new_inp[0]->shape()));
            auto inps = new_inp;
            // input(0) is the filter, input(1) the feature map; dim names
            // follow the forward-conv convention
            size_t out_channels = opr->input(0)->shape()[0];
            size_t in_channels = opr->input(0)->shape()[1];
            size_t new_out_channels = new_inp[1]->shape()[1];
            // pad output channels
            if (padding_oprs.count(opr->input(1)->owner_opr())) {
                // feature-map producer already padded its channels; pad the
                // filter's dim 0 to match
                size_t pad_channels = new_out_channels - out_channels;
                inps[0] = pad_out_channels(new_inp[0], pad_channels);
            } else {
                size_t pad_channels = 0;
                if (out_channels % 4)
                    pad_channels = 4 - (out_channels % 4);
                if (pad_channels > 0) {
                    // filter dim 0 and feature-map channels stay in sync
                    inps[0] = pad_out_channels(new_inp[0], pad_channels);
                    inps[1] = pad_in_channels(new_inp[1], pad_channels);
                }
            }
            out_channels = inps[0]->shape()[0];
            in_channels = inps[0]->shape()[1];
            // pad input channels (these become the deconv output channels)
            size_t pad_channels = 0;
            if (in_channels % 4)
                pad_channels = 4 - (in_channels % 4);
            if (pad_channels > 0) {
                inps[0] = pad_in_channels(inps[0], pad_channels);
                // record: this opr's output var now has padded channels
                padding_oprs.insert(opr);
            }
            return serialization::copy_opr_shallow(*opr, inps,
                                                   opr->config());
        };
// format-aware oprs (pooling, warp perspective) never pad channels
// themselves; for quantized dtypes they simply propagate the padding flag
// from their feature-map producer
auto replace_format_aware_opr = [&padding_oprs](
                                        OperatorNodeBase* opr,
                                        const VarNodeArray& new_inp) {
    auto dtype = opr->input(0)->dtype().enumv();
    bool dtype_supports_padding = dtype == DTypeEnum::QuantizedS8 ||
                                  dtype == DTypeEnum::QuantizedS4 ||
                                  dtype == DTypeEnum::Quantized4Asymm;
    if (!dtype_supports_padding) {
        mgb_assert(padding_oprs.count(opr->input(0)->owner_opr()) == 0,
                   "operator(type:%s,name:%s) for data type(%s) cannot be "
                   "padded channel. extra info:"
                   "consumer(%s), producer(%s)",
                   opr->dyn_typeinfo()->name, opr->cname(),
                   opr->input(0)->dtype().name(), opr->cname(),
                   opr->input(0)->owner_opr()->cname());
        return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
    }
    mgb_assert(opr->input().size() == new_inp.size());
    if (padding_oprs.count(opr->input(0)->owner_opr())) {
        // input channels are padded, so this opr's output is padded too
        padding_oprs.insert(opr);
    }
    return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
opr_replace_funcs[opr::PoolingForward::typeinfo()] =
        replace_format_aware_opr;
opr_replace_funcs[opr::WarpPerspectiveForward::typeinfo()] =
        replace_format_aware_opr;
// elemwise-like oprs (Elemwise, ElemwiseMultiType, TypeCvt) can operate on
// padded tensors only when every input carries the same padded channel
// count; with a mix of padded and unpadded inputs, the padded ones are
// sliced back to their original shapes first
auto replace_elemwise_like_opr = [&padding_oprs, &extract_subtensor](
                                         OperatorNodeBase* opr,
                                         const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    bool have_padding_inp = false;   // at least one input was padded
    bool padding_all_inps = true;    // all inputs padded to the same count
    bool same_padding = true;        // padded inputs agree on channel count
    size_t channels_after_padding = 0;
    size_t i = 0;
    for (auto&& cur_inp : opr->input()) {
        bool padding_cur_inp = padding_oprs.count(cur_inp->owner_opr()) > 0;
        if (padding_cur_inp) {
            if (!have_padding_inp)
                have_padding_inp = true;
            if (channels_after_padding == 0) {
                // first padded input establishes the reference channel count
                channels_after_padding = new_inp[i]->shape()[1];
            } else {
                same_padding =
                        channels_after_padding == new_inp[i]->shape()[1];
            }
        }
        if (padding_all_inps && (!padding_cur_inp || !same_padding))
            padding_all_inps = false;
        ++i;
    }
    if (have_padding_inp && !padding_all_inps) {
        // mixed padded/unpadded inputs: strip padding from the padded ones
        auto inps = new_inp;
        for (size_t i = 0; i < new_inp.size(); ++i) {
            auto cur_inp = opr->input(i);
            bool padding_cur_inp =
                    padding_oprs.count(cur_inp->owner_opr()) > 0;
            if (padding_cur_inp) {
                inps[i] = extract_subtensor(inps[i], cur_inp->shape());
            }
        }
        return serialization::copy_opr_shallow(*opr, inps, opr->config());
    }
    if (padding_all_inps) {
        // uniformly padded inputs: the result stays padded as well
        padding_oprs.insert(opr);
    }
    return serialization::copy_opr_shallow(*opr, new_inp, opr->config());
};
opr_replace_funcs[opr::ElemwiseMultiType::typeinfo()] =
        replace_elemwise_like_opr;
opr_replace_funcs[opr::Elemwise::typeinfo()] = replace_elemwise_like_opr;
opr_replace_funcs[opr::TypeCvt::typeinfo()] = replace_elemwise_like_opr;
// shape-sensitive oprs (Reshape, GetVarShape, Concat, Reduce, Subtensor)
// must see exactly the original channel count, so any padded input is
// sliced back to its original shape before the opr is copied
auto replace_nonpadding_oprs = [&padding_oprs, &extract_subtensor](
                                       OperatorNodeBase* opr,
                                       const VarNodeArray& new_inp) {
    mgb_assert(opr->input().size() == new_inp.size());
    auto inps = new_inp;
    size_t nr_inps = new_inp.size();
    for (size_t idx = 0; idx < nr_inps; ++idx) {
        auto orig_inp = opr->input(idx);
        bool produced_padded =
                padding_oprs.count(orig_inp->owner_opr()) > 0;
        if (produced_padded) {
            // restore the pre-padding shape of this input
            inps[idx] = extract_subtensor(inps[idx], orig_inp->shape());
        }
    }
    return serialization::copy_opr_shallow(*opr, inps, opr->config());
};
opr_replace_funcs[opr::Reshape::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::GetVarShape::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Concat::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Reduce::typeinfo()] = replace_nonpadding_oprs;
opr_replace_funcs[opr::Subtensor::typeinfo()] = replace_nonpadding_oprs;
// per-opr visit callback: rewrite the opr through its registered replace
// function (if any) and fix up shape mismatches at graph endpoints
auto on_opr = [&opt, &rewriter, &opr_replace_funcs,
               &extract_subtensor](OperatorNodeBase* opr) {
    auto it = opr_replace_funcs.find(opr->dyn_typeinfo());
    if (it != opr_replace_funcs.end()) {
        // collect the (possibly already rewritten) input vars
        VarNodeArray new_inp;
        new_inp.reserve(opr->input().size());
        for (auto&& inp : opr->input()) {
            new_inp.push_back(rewriter.get_var(inp));
        }
        auto new_opr = (it->second)(opr, new_inp);
        auto &&out0 = opr->output(), &&out1 = new_opr->output();
        mgb_assert(out0.size() == out1.size(),
                   "bad opr replace: src=%s{%s} dst=%s{%s}, "
                   "src.size=%zu "
                   "dst.size=%zu",
                   opr->cname(), opr->dyn_typeinfo()->name,
                   new_opr->cname(), new_opr->dyn_typeinfo()->name,
                   out0.size(), out1.size());
        for (size_t i = 0; i < out0.size(); ++i) {
            if (!out0[i]->contain_flag(VarNode::Flag::VOLATILE_CONTENT)) {
                mgb_assert(!out1[i]->contain_flag(
                        VarNode::Flag::VOLATILE_CONTENT));
                auto src = out0[i];
                auto dst = out1[i];
                // graph endpoints must keep their original shape: slice
                // away any channel padding before the substitution
                if (opt.graph().endpoint_contain(src) &&
                    !src->shape().eq_shape(dst->shape())) {
                    dst = extract_subtensor(dst, src->shape());
                }
                rewriter.replace_var(src, dst, nullptr);
            }
        }
    } else {
        // no special rule for this opr type: keep it, only remap inputs
        rewriter.auto_replace_outputs(opr);
    }
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
/* ================ EnableNCHW64Pass =============== */
VarNode* EnableNCHW64Pass::on_graph_endpoint_var(VarNode* new_var,
                                                 VarNode* orig_var) const {
    // shape unchanged: the endpoint already has its original layout
    if (orig_var->shape().eq_shape(new_var->shape()))
        return new_var;
    // shape changed: the producing opr emitted a non-NCHW format; look it
    // up and insert a relayout back to NCHW
    auto iter = m_opr_format_map.find(new_var->owner_opr());
    mgb_assert(iter != m_opr_format_map.end(),
               "cannot find opr(type:%s,name:%s) information, related "
               "output var node(name:%s)",
               new_var->owner_opr()->dyn_typeinfo()->name,
               new_var->owner_opr()->cname(), new_var->cname());
    using LayoutType = RelayoutPlaceholder::LayoutType;
    LayoutType back_to_nchw;
    switch (iter->second) {
        case Format::NCHW4:
            back_to_nchw = LayoutType::NCHW4_TO_NCHW;
            break;
        case Format::NCHW32:
            back_to_nchw = LayoutType::NCHW32_TO_NCHW;
            break;
        case Format::NCHW64:
            back_to_nchw = LayoutType::NCHW64_TO_NCHW;
            break;
        case Format::NHWC:
            back_to_nchw = LayoutType::NHWC_TO_NCHW;
            break;
        default:
            mgb_throw(AssertionError,
                      "format(%d) is not supported, related var "
                      "node(name:%s)",
                      static_cast<int>(iter->second), orig_var->cname());
    }
    return RelayoutPlaceholder::make(new_var, back_to_nchw).node();
}
std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
MIDOUT_B("EnableNCHW64Pass::make")
auto ret = std::make_unique<EnableNCHW64Pass>();
ret->set_var_replace_check_flag(VarReplaceCheckFlag::CHECK_ALL ^
VarReplaceCheckFlag::CHECK_SHAPE);
auto& replace_func = ret->m_opr_replace_func;
auto& format_map = ret->m_opr_format_map;
// rebuild `orig_conv` with the same inputs/policy/config but a different
// tensor format, forwarding however many inputs (src, filter[, bias[, z]])
// were supplied
auto make_new_conv = [](const VarNodeArray& inps,
                        const opr::ConvBiasForward* orig_conv,
                        Format format) {
    auto param = orig_conv->param();
    // change format
    param.format = format;
    switch (inps.size()) {
        case 2:
            return opr::ConvBiasForward::make(inps[0], inps[1], param,
                                              orig_conv->execution_policy(),
                                              orig_conv->config())
                    .node();
        case 3:
            return opr::ConvBiasForward::make(inps[0], inps[1], inps[2],
                                              param,
                                              orig_conv->execution_policy(),
                                              orig_conv->config())
                    .node();
        default:
            mgb_assert(inps.size() == 4);
            return opr::ConvBiasForward::make(inps[0], inps[1], inps[2],
                                              inps[3], param,
                                              orig_conv->execution_policy(),
                                              orig_conv->config())
                    .node();
    }
};
// Try to fall back to plain NCHW for a float32 conv-like opr: relayout any
// input produced in a vectorized format back to NCHW and shallow-copy the
// opr. Returns nullptr when the dtype check fails (caller tries another
// format).
// Fix: removed the duplicated trailing process-loop / `ret` declaration
// that appeared after the return statement (unreachable code and a
// redefinition of `ret`).
auto try_transform_to_nchw =
        [&format_map](OperatorNodeBase* opr,
                      const VarNodeArray& new_inp) -> VarNode* {
    mgb_assert(opr->input().size() == new_inp.size());
    // this fallback only applies when every input is float32
    bool check_dtype = new_inp[0]->dtype().enumv() == DTypeEnum::Float32 &&
                       new_inp[1]->dtype().enumv() == DTypeEnum::Float32;
    if (opr->input().size() >= 3)
        check_dtype &= new_inp[2]->dtype().enumv() == DTypeEnum::Float32;
    if (opr->input().size() >= 4)
        check_dtype &= new_inp[3]->dtype().enumv() == DTypeEnum::Float32;
    if (!check_dtype)
        return nullptr;
    auto inps = new_inp;
    // relayout input i back to NCHW if its producer emitted another format
    auto process = [&](size_t i) -> VarNode* {
        auto iter = format_map.find(new_inp[i]->owner_opr());
        if (iter == format_map.end()) {
            return inps[i];
        } else {
            const auto& fmt = iter->second;
            if (fmt == Format::NCHW32) {
                auto ovar = RelayoutPlaceholder::make(
                        inps[i],
                        RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW);
                return ovar.node();
            } else if (fmt == Format::NCHW4) {
                auto ovar = RelayoutPlaceholder::make(
                        inps[i],
                        RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW);
                return ovar.node();
            } else if (fmt == Format::NHWC) {
                auto ovar = RelayoutPlaceholder::make(
                        inps[i],
                        RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW);
                return ovar.node();
            } else {
                mgb_assert(fmt == Format::NCHW64);
                auto ovar = RelayoutPlaceholder::make(
                        inps[i],
                        RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW);
                return ovar.node();
            }
        }
    };
    for (size_t i = 0; i < inps.size(); ++i) {
        inps[i] = process(i);
    }
    auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
    return ret->output()[0];
};
auto try_transform_to_nchw4 =
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
mgb_assert(opr->output().size() > 0);
bool dst_float = opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
if (opr->input().size() >= 3) {
auto dtype_expect =
dst_float ? DTypeEnum::Float32 : DTypeEnum::QuantizedS32;
auto dtype_expect = dst_float ? DTypeEnum::Float32
: DTypeEnum::QuantizedS32;
check_dtype &= new_inp[2]->dtype().enumv() == dtype_expect;
}
if (opr->input().size() >= 4) {
......@@ -4689,29 +2620,18 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto iter = format_map.find(new_inp[i]->owner_opr());
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4);
inps[i], ReformatKey{TensorFormats::NCHW,
TensorFormats::NCHWc4});
return ovar.node();
} else {
const auto& fmt = iter->second;
if (fmt == Format::NCHW4) {
return inps[i];
} else if (fmt == Format::NCHW32) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW4);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW4);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW4);
return ovar.node();
ReformatKey key;
key.input_format = opr_format_to_tensor_formats(fmt);
key.output_format = TensorFormats::NCHWc4;
return RelayoutPlaceholder::make(inps[i], key).node();
}
}
};
......@@ -4719,13 +2639,12 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
for (size_t i = 0; i < inps.size(); ++i) {
// do not format bias and z when dst_float is true
bool skip = dst_float && i >= 2;
if (!skip)
inps[i] = process(i);
if (!skip) inps[i] = process(i);
}
auto& conv_bias = opr->cast_final_safe<opr::ConvBiasForward>();
auto ret =
make_new_conv(inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
auto ret = make_new_conv(
inps, &conv_bias,
dst_float ? Format::NCHW4_NCHW : Format::NCHW4);
if (!dst_float)
format_map.insert(std::make_pair(ret->owner_opr(), Format::NCHW4));
return ret;
......@@ -4735,7 +2654,7 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
[make_new_conv, &format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS8 &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS8;
......@@ -4755,31 +2674,18 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(new_inp[i]->owner_opr());
ReformatKey key;
key.output_format = TensorFormats::NCHWc32;
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW32);
return ovar.node();
key.input_format = TensorFormats::NCHW;
return RelayoutPlaceholder::make(inps[i], key).node();
} else {
const auto& fmt = iter->second;
if (fmt == Format::NCHW32) {
return inps[i];
} else if (fmt == Format::NCHW4) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW32);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW32);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW64_TO_NCHW32);
return ovar.node();
key.input_format = opr_format_to_tensor_formats(fmt);
return RelayoutPlaceholder::make(inps[i], key).node();
}
}
};
......@@ -4797,17 +2703,18 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
// fint4XWint4 and fuint4XWint4
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) &&
new_inp[0]->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &=
new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv();
check_dtype &= new_inp[3]->dtype().enumv() ==
new_inp[0]->dtype().enumv();
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
......@@ -4818,31 +2725,20 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(new_inp[i]->owner_opr());
ReformatKey key;
key.output_format = TensorFormats::NCHWc64;
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW64);
return ovar.node();
key.input_format = TensorFormats::NCHW;
key.input_dtype = key.output_dtype = inps[i]->dtype().enumv();
return RelayoutPlaceholder::make(inps[i], key).node();
} else {
const auto& fmt = iter->second;
if (fmt == Format::NCHW64) {
return inps[i];
} else if (fmt == Format::NCHW4) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NCHW64);
return ovar.node();
} else if (fmt == Format::NHWC) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NHWC_TO_NCHW64);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW32);
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NCHW64);
return ovar.node();
key.input_format = opr_format_to_tensor_formats(fmt);
key.input_dtype = key.output_dtype = inps[i]->dtype().enumv();
return RelayoutPlaceholder::make(inps[i], key).node();
}
}
};
......@@ -4860,17 +2756,18 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
OperatorNodeBase* opr,
const VarNodeArray& new_inp) -> VarNode* {
// fint4XWint4 and fuint4XWint4
mgb_assert(opr->input().size() == new_inp.size());
mgb_assert(opr->input().size()==new_inp.size());
bool check_dtype =
(new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
new_inp[0]->dtype().enumv() == DTypeEnum::Quantized4Asymm) &&
new_inp[0]->dtype().enumv() ==
DTypeEnum::Quantized4Asymm) &&
new_inp[1]->dtype().enumv() == DTypeEnum::QuantizedS4;
if (opr->input().size() >= 3)
check_dtype &=
new_inp[2]->dtype().enumv() == DTypeEnum::QuantizedS32;
if (opr->input().size() >= 4)
check_dtype &=
new_inp[3]->dtype().enumv() == new_inp[0]->dtype().enumv();
check_dtype &= new_inp[3]->dtype().enumv() ==
new_inp[0]->dtype().enumv();
if (!check_dtype)
return nullptr;
size_t out_channels = opr->input(1)->shape()[0];
......@@ -4881,30 +2778,19 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto inps = new_inp;
auto process = [&](size_t i) -> VarNode* {
auto iter = format_map.find(new_inp[i]->owner_opr());
ReformatKey key;
key.output_format = TensorFormats::NHWC;
key.input_dtype = key.output_dtype = inps[i]->dtype().enumv();
if (iter == format_map.end()) {
auto ovar = RelayoutPlaceholder::make(
inps[i], RelayoutPlaceholder::LayoutType::NCHW_TO_NHWC);
return ovar.node();
key.input_format = TensorFormats::NCHW;
return RelayoutPlaceholder::make(inps[i], key).node();
} else {
const auto& fmt = iter->second;
if (fmt == Format::NHWC) {
return inps[i];
} else if (fmt == Format::NCHW4) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW4_TO_NHWC);
return ovar.node();
} else if (fmt == Format::NCHW32) {
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW32_TO_NHWC);
return ovar.node();
} else {
mgb_assert(fmt == Format::NCHW64);
auto ovar = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::NCHW64_TO_NHWC);
return ovar.node();
key.input_format = opr_format_to_tensor_formats(fmt);
return RelayoutPlaceholder::make(inps[i], key).node();
}
}
};
......@@ -4982,38 +2868,17 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
}
auto inps = new_inp;
inps[0] = RelayoutPlaceholder::make(
inps[0],
RelayoutPlaceholder::LayoutType::NCHW_TO_NCHW4)
inps[0], ReformatKey{TensorFormats::NCHW,
TensorFormats::NCHWc4})
.node();
switch (cur) {
case Format::NCHW:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW4)
.node();
break;
case Format::NHWC:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW4)
.node();
break;
case Format::NCHW32:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW4)
.node();
break;
case Format::NCHW64:
inps[1] = RelayoutPlaceholder::make(
inps[1], RelayoutPlaceholder::LayoutType::
NCHW64_TO_NCHW4)
.node();
break;
default:
mgb_assert(cur == Format::NCHW4);
}
if (cur != Format::NCHW4) {
inps[1] = RelayoutPlaceholder::make(
inps[1],
ReformatKey{opr_format_to_tensor_formats(cur),
TensorFormats::NCHWc4})
.node();
}
auto param = deconv.param();
param.format = Format::NCHW4;
auto new_deconv = opr::ConvolutionBackwardData::make(
......@@ -5030,7 +2895,7 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
break;
}
}
mgb_assert(!shape_changed,
mgb_assert(!shape_changed,
"EnableNCHW64Pass won't change format of output tensor "
"of non quantized deconv operator(name:%s)",
opr->cname());
......@@ -5040,9 +2905,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
};
// replace rule for elemwise like opr
auto replace_elemwise_like_opr = [&format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
auto replace_elemwise_like_opr = [&format_map](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
ThinHashMap<Format, size_t> format_size;
bool same_format = true;
......@@ -5082,28 +2946,6 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
max_size = item.second;
}
}
static const ThinHashMap<std::pair<Format, Format>,
thin_function<VarNode*(VarNode*)>>
map = {
#define cb(_fmt1, _fmt2) \
{ \
std::make_pair(Format::_fmt1, Format::_fmt2), \
[](VarNode* in) -> VarNode* { \
return RelayoutPlaceholder::make( \
in, RelayoutPlaceholder::LayoutType:: \
_fmt1##_TO_##_fmt2) \
.node(); \
} \
}
cb(NCHW, NCHW4), cb(NCHW, NCHW32), cb(NCHW, NCHW64),
cb(NCHW4, NCHW), cb(NCHW4, NCHW32), cb(NCHW4, NCHW64),
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64),
cb(NCHW32, NCHW), cb(NCHW32, NCHW4), cb(NCHW32, NCHW64),
cb(NCHW, NHWC), cb(NCHW4, NHWC), cb(NCHW32, NHWC),
cb(NCHW64, NHWC), cb(NHWC, NCHW), cb(NHWC, NCHW4),
cb(NHWC, NCHW32), cb(NHWC, NCHW64),
#undef cb
};
auto inps = new_inp;
for (size_t i = 0; i < opr->input().size(); ++i) {
auto iter = format_map.find(new_inp[i]->owner_opr());
......@@ -5114,7 +2956,10 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
cur = Format::NCHW;
}
if (cur != max_format) {
inps[i] = map.at(std::make_pair(cur, max_format))(inps[i]);
ReformatKey key{opr_format_to_tensor_formats(cur),
opr_format_to_tensor_formats(max_format)};
key.input_dtype = key.output_dtype = inps[i]->dtype().enumv();
inps[i] = RelayoutPlaceholder::make(inps[i], key).node();
}
}
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
......@@ -5144,27 +2989,11 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
cur = iter->second;
}
auto inps = new_inp;
switch (cur) {
case Format::NCHW:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NHWC)
.node();
break;
case Format::NCHW4:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NHWC)
.node();
break;
case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NHWC)
.node();
break;
default:
mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC);
if (cur != Format::NCHW64 && cur != Format::NHWC) {
ReformatKey key{opr_format_to_tensor_formats(cur),
TensorFormats::NHWC, inps[0]->dtype().enumv(),
inps[0]->dtype().enumv()};
inps[0] = RelayoutPlaceholder::make(inps[0], key).node();
}
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = warp.param();
......@@ -5172,7 +3001,8 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
SymbolVar new_warp;
if (inps.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make(
inps[0], inps[1], inps[2], param, warp.config());
inps[0], inps[1], inps[2], param,
warp.config());
} else {
mgb_assert(inps.size() == 4);
new_warp = opr::WarpPerspectiveForward::make(
......@@ -5191,41 +3021,20 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
cur = iter->second;
}
auto inps = new_inp;
switch (cur) {
case Format::NCHW:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NCHW4)
.node();
break;
case Format::NHWC:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW4)
.node();
break;
case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW4)
.node();
break;
case Format::NCHW64:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW64_TO_NCHW4)
.node();
break;
default:
mgb_assert(cur == Format::NCHW4);
if (cur != Format::NCHW4) {
inps[0] = RelayoutPlaceholder::make(
inps[0],
ReformatKey{opr_format_to_tensor_formats(cur),
TensorFormats::NCHWc4})
.node();
}
auto param = warp.param();
param.format = Format::NCHW4;
SymbolVar new_warp;
if (inps.size() == 3) {
new_warp = opr::WarpPerspectiveForward::make(
inps[0], inps[1], inps[2], param, warp.config());
inps[0], inps[1], inps[2], param,
warp.config());
} else {
mgb_assert(inps.size() == 4);
new_warp = opr::WarpPerspectiveForward::make(
......@@ -5243,7 +3052,7 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
break;
}
}
mgb_assert(!shape_changed,
mgb_assert(!shape_changed,
"EnableNCHW64Pass won't change format of output tensor "
"of non quantized warp perspective operator(name:%s)",
opr->cname());
......@@ -5251,8 +3060,9 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
opr->config());
}
};
auto replace_pooling_opr = [&format_map](OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
auto replace_pooling_opr = [&format_map](
OperatorNodeBase* opr,
const VarNodeArray& new_inp) {
mgb_assert(opr->input().size() == new_inp.size());
auto& pooling = opr->cast_final_safe<opr::PoolingForward>();
if (new_inp[0]->dtype().enumv() == DTypeEnum::QuantizedS4 ||
......@@ -5265,27 +3075,11 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
cur = iter->second;
}
auto inps = new_inp;
switch (cur) {
case Format::NCHW:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW_TO_NHWC)
.node();
break;
case Format::NCHW4:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW4_TO_NHWC)
.node();
break;
case Format::NCHW32:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW32_TO_NHWC)
.node();
break;
default:
mgb_assert(cur == Format::NCHW64 || cur == Format::NHWC);
if (cur != Format::NCHW64 && cur != Format::NHWC) {
ReformatKey key{opr_format_to_tensor_formats(cur),
TensorFormats::NHWC, inps[0]->dtype().enumv(),
inps[0]->dtype().enumv()};
inps[0] = RelayoutPlaceholder::make(inps[0], key).node();
}
auto target_format = cur == Format::NCHW64 ? cur : Format::NHWC;
auto param = pooling.param();
......@@ -5305,30 +3099,31 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
}
bool use_nchw32 = false;
auto inps = new_inp;
using LayoutType = RelayoutPlaceholder::LayoutType;
ReformatKey key;
switch (cur) {
case Format::NCHW: {
size_t in_channels = new_inp[0]->shape()[1];
use_nchw32 = in_channels % 32 == 0;
auto layout_type = use_nchw32 ? LayoutType::NCHW_TO_NCHW32
: LayoutType::NCHW_TO_NCHW4;
inps[0] = RelayoutPlaceholder::make(inps[0], layout_type)
.node();
key.input_format = TensorFormats::NCHW;
key.output_format = use_nchw32 ? TensorFormats::NCHWc32
: TensorFormats::NCHWc4;
inps[0] = RelayoutPlaceholder::make(inps[0], key).node();
break;
}
case Format::NHWC: {
size_t in_channels = new_inp[0]->shape()[3];
use_nchw32 = in_channels % 32 == 0;
auto layout_type = use_nchw32 ? LayoutType::NHWC_TO_NCHW32
: LayoutType::NHWC_TO_NCHW4;
inps[0] = RelayoutPlaceholder::make(inps[0], layout_type)
.node();
key.input_format = TensorFormats::NHWC;
key.output_format = use_nchw32 ? TensorFormats::NCHWc32
: TensorFormats::NCHWc4;
inps[0] = RelayoutPlaceholder::make(inps[0], key).node();
break;
}
case Format::NCHW64:
inps[0] = RelayoutPlaceholder::make(
inps[0], RelayoutPlaceholder::LayoutType::
NCHW64_TO_NCHW32)
inps[0],
ReformatKey{TensorFormats::NCHWc64,
TensorFormats::NCHWc32})
.node();
break;
case Format::NCHW32:
......@@ -5338,7 +3133,7 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
mgb_assert(cur == Format::NCHW4);
}
Format out_format = use_nchw32 ? Format::NCHW32 : Format::NCHW4;
auto param = pooling.param();
param.format = out_format;
auto new_pool =
......@@ -5374,39 +3169,12 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
auto inps = new_inp;
for (size_t i = 0; i < opr->input().size(); ++i) {
auto iter = format_map.find(new_inp[i]->owner_opr());
auto fmt = iter != format_map.end() ? iter->second : Format::NCHW;
auto fmt = iter != format_map.end()?iter->second:Format::NCHW;
if (iter != format_map.end()) {
switch (fmt) {
case Format::NHWC:
inps[i] = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::
NHWC_TO_NCHW)
.node();
break;
case Format::NCHW4:
inps[i] = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::
NCHW4_TO_NCHW)
.node();
break;
case Format::NCHW32:
inps[i] = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::
NCHW32_TO_NCHW)
.node();
break;
default:
mgb_assert(fmt == Format::NCHW64);
inps[i] = RelayoutPlaceholder::make(
inps[i],
RelayoutPlaceholder::LayoutType::
NCHW64_TO_NCHW)
.node();
break;
}
ReformatKey key{opr_format_to_tensor_formats(fmt),
TensorFormats::NCHW, inps[i]->dtype().enumv(),
inps[i]->dtype().enumv()};
inps[i] = RelayoutPlaceholder::make(inps[i], key).node();
}
}
auto ret = serialization::copy_opr_shallow(*opr, inps, opr->config());
......@@ -5422,4 +3190,5 @@ std::unique_ptr<EnableNCHW64Pass> EnableNCHW64Pass::make_nchw64_converter() {
return ret;
MIDOUT_E
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -227,6 +227,7 @@ namespace gopt {
VarReplaceCheckFlag m_var_replace_check_flag =
VarReplaceCheckFlag::CHECK_ALL;
class RelayoutPlaceholder;
friend class ShuffleShuffleRemovePass;
public:
TensorReformatPass& set_var_replace_check_flag(VarReplaceCheckFlag flag) {
......
......@@ -49,10 +49,14 @@ enum class TensorFormats : uint32_t {
KRSCk8 = 21, ///< [K/8, R, S, C, K%8]
// NCHW4
KCRSc4 = 22, ///< [K, C/4, R, S, C%4]
GKCRSc4 = 23, ///< [G, K, C/4, R, S, C%4]
// default weight format
KCRS = 22, ///< [K, C, R, S]
GKCRS = 23, ///< [G, K, C, R, S]
C11RS = 24, ///< [C, 1, 1, R, S]
KCRS = 24, ///< [K, C, R, S]
GKCRS = 25, ///< [G, K, C, R, S]
C11RS = 26, ///< [C, 1, 1, R, S]
};
class ReformatManager : public NonCopyableObj {
......@@ -60,16 +64,20 @@ class ReformatManager : public NonCopyableObj {
public:
using ReformatImpl = thin_function<VarNode*(const VarNodeArray&)>;
enum class Attribute : uint32_t {
DEFAULT = 0,
IMAGE2D = 1 << 0,
IC_SMALL = 1 << 1,
};
struct ReformatKey {
enum class Attribute : uint32_t {
DEFAULT = 0,
IMAGE2D = 1 << 0,
IC_SMALL = 1 << 1,
};
TensorFormats input_format, output_format;
DTypeEnum input_dtype, output_dtype;
Attribute attribute;
std::string to_string() const;
ReformatKey()
: input_dtype{DTypeEnum::Float32},
output_dtype{DTypeEnum::Float32},
attribute{Attribute::DEFAULT} {}
ReformatKey(TensorFormats input_format_, TensorFormats output_format_,
Attribute attribute_ = Attribute::DEFAULT,
DTypeEnum input_dtype_ = DTypeEnum::Float32,
......@@ -86,11 +94,13 @@ public:
bool operator()(const ReformatKey& lhs,
const ReformatKey& rhs) const;
};
ReformatKey& deduce_reformat_dtype_enum(const DType& dt);
};
using ReformatCache =
std::unordered_map<ReformatKey, ReformatImpl, ReformatKey::Hash,
ReformatKey::Equal>;
const ReformatImpl& get(const ReformatKey& key) const;
ReformatImpl get(const ReformatKey& key) const;
ReformatImpl get(ReformatKey&& key) const { return get(key); }
static const ReformatManager& instance();
private:
......
/**
* \file src/gopt/test/reformat_manager.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "./helper.h"
#include "megbrain/gopt/reformat_manager.h"
#include "megbrain/opr/tensor_manip.h"
using namespace mgb;
using namespace gopt;
// Check the feature-map reformat NHWC -> NCHWc64: the subgraph produced by
// ReformatManager must compute the same tensor as a hand-written
// reshape + dimshuffle reference, and must contain exactly one GetVarShape
// and one Reshape operator.
TEST(TestReformatManager, Feature) {
    constexpr size_t N = 16, C = 128, H = 7, W = 7;
    HostTensorGenerator<> gen;
    using ReformatKey = ReformatManager::ReformatKey;
    ReformatKey key{TensorFormats::NHWC, TensorFormats::NCHWc64};
    auto reformat = ReformatManager::instance().get(key);
    auto graph = ComputingGraph::make();
    // disable graph optimization so we can count the emitted operators
    graph->options().graph_opt_level = 0;
    // reference implementation: [N, H, W, C] -> [N, C/64, H, W, 64]
    auto reference = [](VarNode* inp) {
        auto x = SymbolVar(inp);
        auto shp = opr::GetVarShape::make(x);
        auto scalar = [&x](int v) { return x.make_scalar(v); };
        auto dim = [&shp, &scalar](int idx) {
            return opr::IndexAt::make(shp, {{0, scalar(idx)}});
        };
        auto tgt_shp = opr::Concat::make(
                {dim(0), dim(1), dim(2), dim(3) / 64, scalar(64)}, 0);
        auto reshaped = opr::Reshape::make(x, tgt_shp);
        return opr::Dimshuffle::make(reshaped, {0, 3, 1, 2, 4});
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };
    auto x = mkvar("x", {N, H, W, C});
    auto y1 = SymbolVar(reformat({x.node()}));
    auto y2 = reference(x.node());
    // the generated reformat subgraph should query the shape once and
    // reshape once
    size_t num_shapeof = 0, num_reshape = 0;
    cg::DepOprIter{[&num_shapeof, &num_reshape](cg::OperatorNodeBase* o) {
        if (o->same_type<opr::GetVarShape>())
            ++num_shapeof;
        if (o->same_type<opr::Reshape>())
            ++num_reshape;
    }}.add(y1.node()->owner_opr());
    ASSERT_EQ(num_shapeof, 1);
    ASSERT_EQ(num_reshape, 1);
    // both graphs must produce identical tensors
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// Check the grouped-weight reformat GKCRS -> GKCRSk4c4 against a
// hand-written reshape + dimshuffle + reshape reference, and verify the
// emitted subgraph holds exactly one GetVarShape and one Reshape operator.
TEST(TestReformatManager, Weight) {
    constexpr size_t G = 8, K = 128, C = 128, R = 3, S = 3;
    HostTensorGenerator<> gen;
    using ReformatKey = ReformatManager::ReformatKey;
    ReformatKey key{TensorFormats::GKCRS, TensorFormats::GKCRSk4c4};
    auto reformat = ReformatManager::instance().get(key);
    auto graph = ComputingGraph::make();
    // disable graph optimization so we can count the emitted operators
    graph->options().graph_opt_level = 0;
    // reference: [G, K, C, R, S] -> [G, K/4, C/4, R, S, K%4, C%4]
    auto reference = [](VarNode* inp) {
        auto x = SymbolVar(inp);
        auto shp = opr::GetVarShape::make(x);
        auto scalar = [&x](int v) { return x.make_scalar(v); };
        auto dim = [&shp, &scalar](int idx) {
            return opr::IndexAt::make(shp, {{0, scalar(idx)}});
        };
        auto mid_shp = opr::Concat::make(
                {dim(0), dim(1) / 4, scalar(4), dim(2) / 4, scalar(4), dim(3),
                 dim(4)},
                0);
        auto dst_shp = opr::Concat::make(
                {dim(0), dim(1) / 4, dim(2) / 4, dim(3), dim(4), scalar(4),
                 scalar(4)},
                0);
        auto split = opr::Reshape::make(x, mid_shp);
        auto shuffled = opr::Dimshuffle::make(split, {0, 1, 3, 5, 6, 2, 4});
        return opr::Reshape::make(shuffled, dst_shp);
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp)).rename(name);
    };
    auto w = mkvar("w", {G, K / G, C / G, R, S});
    auto y1 = SymbolVar(reformat({w.node()}));
    auto y2 = reference(w.node());
    // only the reformat-generated graph (y1) is inspected here
    size_t num_shapeof = 0, num_reshape = 0;
    cg::DepOprIter{[&num_shapeof, &num_reshape](cg::OperatorNodeBase* o) {
        if (o->same_type<opr::GetVarShape>())
            ++num_shapeof;
        if (o->same_type<opr::Reshape>())
            ++num_reshape;
    }}.add(y1.node()->owner_opr());
    ASSERT_EQ(num_shapeof, 1);
    ASSERT_EQ(num_reshape, 1);
    // both graphs must produce identical tensors
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// A key combining weight formats with the IMAGE2D attribute has no
// registered reformat, so the manager lookup must raise an AssertionError.
TEST(TestReformatManager, InvalidKey) {
    using ReformatKey = ReformatManager::ReformatKey;
    using Attribute = ReformatKey::Attribute;
    ReformatKey key{TensorFormats::GKCRS, TensorFormats::GKCRSk4c4,
                    Attribute::IMAGE2D};
    ASSERT_THROW(ReformatManager::instance().get(key), AssertionError);
}
// With the IC_SMALL attribute, the NCHW -> NCHWc4 reformat must match the
// dedicated RelayoutFormat NCHW_NCHW4_IC_SMALL mode for inputs with a small
// channel count (C = 3 here).
TEST(TestReformatManager, InputChannelSmall) {
    constexpr size_t N = 16, C = 3, H = 224, W = 224;
    auto cn = CompNode::load("cpux");
    HostTensorGenerator<> gen;
    using ReformatKey = ReformatManager::ReformatKey;
    using Attribute = ReformatKey::Attribute;
    ReformatKey key{TensorFormats::NCHW, TensorFormats::NCHWc4,
                    Attribute::IC_SMALL};
    auto reformat = ReformatManager::instance().get(key);
    auto graph = ComputingGraph::make();
    graph->options().graph_opt_level = 0;
    // reference: the dedicated RelayoutFormat operator mode
    auto reference = [](VarNode* inp) {
        return opr::RelayoutFormat::make(
                SymbolVar(inp),
                megdnn::param::RelayoutFormat::Mode::NCHW_NCHW4_IC_SMALL);
    };
    auto mkvar = [&](const char* name, const TensorShape& shp) {
        return opr::Host2DeviceCopy::make(*graph, gen(shp, cn)).rename(name);
    };
    auto x = mkvar("x", {N, C, H, W});
    auto y1 = SymbolVar(reformat({x.node()}));
    auto y2 = reference(x.node());
    // both graphs must produce identical tensors
    HostTensorND t1, t2;
    auto func1 = graph->compile({make_callback_copy(y1, t1)});
    func1->execute();
    auto func2 = graph->compile({make_callback_copy(y2, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册