提交 30976c23 编写于 作者: M Megvii Engine Team

fix(mgb/gopt): fix global layout transform

add a special opr_format modify function for concat operators to modify concat axis when input's layout has been changed

GitOrigin-RevId: 409420805714c5bbd617f66df1a75fe3e68439c6
上级 b6c0e8d0
......@@ -831,7 +831,7 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
dst[4] = 32;
} else if (param().format == Param::Format::NCHW88) {
megdnn_assert(
src.ndim == 5 || src.ndim == 4,
src.ndim == 5 || (src.ndim == 4 && src[1] <= 8),
"invalid src ndim for NCHW88, expected=5 or 4, got=%zu", src.ndim);
dst.ndim = 5;
dst[0] = src[0];
......@@ -854,7 +854,7 @@ typename ConvolutionBase<Parameter>::CanonizedFilterMeta ConvolutionBase<Paramet
param().format == Param::Format::NCHW44 ||
param().format == Param::Format::NCHW44_DOT) {
megdnn_assert(
src.ndim == 5 || src.ndim == 4,
src.ndim == 5 || (src.ndim == 4 && src[1] <= 4),
"invalid src ndim for NCHW44, expected=5 or 4, got=%zu", src.ndim);
dst.ndim = 5;
dst[0] = src[0];
......
......@@ -840,7 +840,9 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
if (need_param_fuse) {
add_pass<ParamFusePass>();
add_pass<ParamMergePass>();
}
return *this;
}
......
......@@ -66,7 +66,6 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
OprFormatConfigID::NCHW4_NCHW32, OprFormatConfigID::NCHW32_NCHW4,
OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
.add_opr_config(
......
......@@ -18,6 +18,7 @@
#include "megbrain/gopt/solver.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/sereg.h"
#include "megbrain/utils/hash_ct.h"
......@@ -64,11 +65,6 @@ void LayoutTransformPass::apply(OptState& opt) const {
auto&& base_cfg_id = m_ctx->attribute().base_config_id;
auto&& reformat_attribute = m_ctx->attribute().reformat_attribute;
ThinHashMap<VarNode*, TensorFormats> var2fmts;
static ThinHashSet<Typeinfo*> format_aware_oprs = {
#define cb(_Opr) opr::_Opr::typeinfo(),
FOREACH_FORMAT_AWARE_OPR(cb)
#undef cb
};
auto rewriter = opt.graph().make_rewriter();
auto on_opr = [&opr_configs, &base_fmt, &base_cfg_id, &reformat_attribute,
&rewriter, &solution, &var2fmts,
......@@ -141,8 +137,12 @@ void LayoutTransformPass::apply(OptState& opt) const {
new_inp[i] = new_var;
}
VarNode* new_out;
if (format_aware_oprs.count(opr->dyn_typeinfo()) > 0) {
new_out = intl::modify_opr_format(opr_fmt.val(), new_inp, opr);
if (intl::has_opr_format_modifier(opr)) {
intl::OprFormatInfo opr_format_info;
opr_format_info.opr_format = opr_fmt.val();
opr_format_info.tensor_formats = {
base_fmt, opr_format_to_tensor_formats(opr_fmt.val())};
new_out = intl::modify_opr_format(opr_format_info, new_inp, opr);
} else {
new_out = serialization::copy_opr_shallow(*opr, new_inp, opr->config())
->output(0);
......
......@@ -11,10 +11,12 @@
*/
#include "./opr_format_modifier.h"
#include "./utils.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/sereg.h"
#include "midout.h"
......@@ -201,6 +203,37 @@ INST(ConvolutionBackwardData)
INST(PoolingForward)
#undef APPLY
#undef INST
VarNode* modify_concat_opr_format(
gopt::intl::OprFormatInfo::TensorFormatsInfo tensor_formats,
const VarNodeArray& i, const cg::OperatorNodeBase* opr) {
auto base_format = tensor_formats.from;
auto tensor_format = tensor_formats.to;
int axis = opr->cast_final_safe<Concat>().axis();
/// modify axis
using Dimension = megdnn::Dimension;
static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT;
auto orig_shape = tensor_formats_to_named_tensor_shape(base_format);
auto target_shape = tensor_formats_to_named_tensor_shape(tensor_format);
mgb_assert(
static_cast<size_t>(axis) < orig_shape.ndim,
"invalid axis of concat opr(axis:%d,shp:%s)", axis,
orig_shape.to_string().c_str());
if (orig_shape[axis].extent() != UNDETERMINED_EXTENT)
return nullptr;
auto axis_name = orig_shape[axis].name();
int new_axis = target_shape.ndim;
for (size_t i = 0; i < target_shape.ndim; ++i) {
if (target_shape[i].name() == axis_name &&
target_shape[i].extent() == UNDETERMINED_EXTENT) {
new_axis = i;
break;
}
}
if (static_cast<size_t>(new_axis) >= target_shape.ndim)
return nullptr;
return opr::Concat::make(i, new_axis, opr->config()).node();
}
} // namespace
namespace mgb {
......@@ -275,13 +308,16 @@ INST(Resize, 2);
#undef INST
VarNode* modify_opr_format(
opr::ConvBias::Param::Format opr_format, const VarNodeArray& i,
OprFormatInfo opr_format_info, const VarNodeArray& i,
const cg::OperatorNodeBase* opr) {
#define cb(_Opr) \
if (opr->dyn_typeinfo() == _Opr::typeinfo()) { \
return OprFormatModifier<_Opr>::make(opr_format, i, opr); \
return OprFormatModifier<_Opr>::make(opr_format_info.opr_format, i, opr); \
} else
FOREACH_FORMAT_AWARE_OPR(cb) {
FOREACH_FORMAT_AWARE_OPR(cb)
if (opr->dyn_typeinfo() == opr::Concat::typeinfo()) {
return modify_concat_opr_format(opr_format_info.tensor_formats, i, opr);
} else {
mgb_throw(
InternalError, "invalid format aware operator(got:%s)",
opr->dyn_typeinfo()->name);
......@@ -302,6 +338,28 @@ bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr)
InternalError, "invalid multi-algo operator(got:%s)",
opr->dyn_typeinfo()->name);
}
#undef cb
}
bool has_opr_format(const cg::OperatorNodeBase* opr) {
bool ret = false;
#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo();
FOREACH_FORMAT_AWARE_OPR(cb)
#undef cb
return ret;
}
bool has_opr_format_modifier(const cg::OperatorNodeBase* opr) {
bool ret = false;
#define cb(_Opr) ret |= opr->dyn_typeinfo() == _Opr::typeinfo();
FOREACH_MODIFY_OPR_FORMAT_OPR(cb)
#undef cb
return ret;
}
bool allow_aligned_layout(const cg::OperatorNodeBase* opr) {
return opr->dyn_typeinfo() != opr::Concat::typeinfo() &&
opr->dyn_typeinfo() != opr::Reduce::typeinfo();
}
} // namespace intl
......
......@@ -16,17 +16,36 @@
namespace mgb {
namespace gopt {
enum class TensorFormats : uint32_t;
namespace intl {
#define FOREACH_FORMAT_AWARE_OPR(cb) \
cb(Convolution) cb(ConvBiasForward) cb(ConvolutionBackwardData) cb(PoolingForward) \
cb(WarpPerspective) cb(Resize)
#define FOREACH_MODIFY_OPR_FORMAT_OPR(cb) FOREACH_FORMAT_AWARE_OPR(cb) cb(Concat)
bool has_available_algo(const VarNodeArray& i, const cg::OperatorNodeBase* opr);
struct OprFormatInfo {
opr::Convolution::Param::Format opr_format;
struct TensorFormatsInfo {
TensorFormats from;
TensorFormats to;
};
TensorFormatsInfo tensor_formats;
};
VarNode* modify_opr_format(
opr::Convolution::Param::Format opr_format, const VarNodeArray& i,
OprFormatInfo opr_format, const VarNodeArray& i,
const cg::OperatorNodeBase* opr);
bool has_opr_format(const cg::OperatorNodeBase* opr);
bool has_opr_format_modifier(const cg::OperatorNodeBase* opr);
bool allow_aligned_layout(const cg::OperatorNodeBase* opr);
} // namespace intl
} // namespace gopt
} // namespace mgb
......
......@@ -625,7 +625,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_HYBRID> {
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= check_dtype(opr->output(0)->dtype(), false);
// FIXME: hack for nchw nchw44 hybrid mode
static_assert(
std::is_same<Opr, opr::ConvolutionForward>::value ||
std::is_same<Opr, opr::ConvBiasForward>::value,
"nchw44 hybrid only support conv or conv_bias opr");
size_t in_channel = opr->input(0)->shape()[1];
available &= in_channel <= 4_z && check_dtype(opr->output(0)->dtype(), false);
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
config.input_tensor_formats = {
......@@ -696,7 +702,14 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW88_HYBRID> {
TensorType tensor_type = i == 1 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
available &= opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
// FIXME: hack for nchw nchw88 hybrid mode
static_assert(
std::is_same<Opr, opr::ConvolutionForward>::value ||
std::is_same<Opr, opr::ConvBiasForward>::value,
"nchw nchw88 hybrid mode only support conv or conv_bias opr");
size_t in_channel = opr->input(0)->shape()[1];
available &= in_channel <= 8_z &&
opr->output(0)->dtype().enumv() == DTypeEnum::Float32;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
// setup tensor formats
......@@ -783,6 +796,13 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NCHW44_DOT_HYBRID
opr->output(0)->dtype().enumv() == DTypeEnum::Quantized8Asymm;
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == Opr::Param::Sparse::DENSE;
// FIXME: hack for nchw nchw44 dot hybrid mode
static_assert(
std::is_same<Opr, opr::ConvolutionForward>::value ||
std::is_same<Opr, opr::ConvBiasForward>::value,
"nchw44 dot hybrid only support conv or conv_bias opr");
size_t in_channel = opr->input(0)->shape()[1];
available &= in_channel <= 4_z;
// setup tensor formats
config.input_tensor_formats = {
TensorFormats::NCHW, TensorFormats::KRSCk4, TensorFormats::NCHWc4,
......@@ -940,6 +960,8 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, CHWN4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW32_NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW4_NCHW32);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW64);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW88);
......
......@@ -19,6 +19,7 @@
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/plugin/base.h"
#include "megbrain/serialization/sereg.h"
......@@ -202,20 +203,43 @@ float ProfilerImpl::profile_operator(
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
graph->options().var_sanity_check_first_run = false;
OperatorNodeBase* new_opr;
/// \note: Concat operators are specially treated. The reasons are as
/// follows:
/// 1. Padding the input varnodes of the
/// Concat opr is not allowed. If we pad the input varnodes of Concat opr, the
/// padding information of output varnode should be propagated in the layout
/// selection algorithm. But this feature has not been implemented now.
/// 2. The axis of concat operator should be modified, because the layouts of the
/// input varnodes has been modified. So we handle the new axis in the OprMaker
/// function.
bool allow_aligned = intl::allow_aligned_layout(opr);
VarNodeArray new_inps(opr->input().size());
for (size_t i = 0; i < opr->input().size(); ++i) {
auto&& var = opr->input(i);
auto&& cn = var->comp_node();
auto&& dtype = var->dtype();
auto dval = std::make_shared<DeviceTensorND>(cn, dtype);
auto aligned_tensor_shape = ReformatManager::make_aligned_tensor_shape(
var, base_format, tensor_format, extra_attribute);
dval->resize(aligned_tensor_shape);
auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
new_inps[i] = aligned_var.node();
}
auto new_opr = serialization::copy_opr_shallow(
auto new_shape = ReformatManager::try_make_tensor_shape(
var, base_format, tensor_format, extra_attribute, allow_aligned);
if (new_shape.ndim == 0)
return PROFILE_TIME_OUT;
dval->resize(new_shape);
auto new_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
new_inps[i] = new_var.node();
}
if (intl::has_opr_format_modifier(opr)) {
intl::OprFormatInfo opr_format_info;
opr_format_info.tensor_formats = {base_format, tensor_format};
auto new_var = intl::modify_opr_format(opr_format_info, new_inps, opr);
if (new_var)
new_opr = new_var->owner_opr();
else
return PROFILE_TIME_OUT;
} else {
new_opr = serialization::copy_opr_shallow(
*opr, new_inps, opr->config(), {graph.get()});
}
if (!m_opr_filter(opr, new_opr))
return PROFILE_TIME_OUT;
auto y = new_opr->output(0);
......@@ -248,6 +272,8 @@ float ProfilerImpl::profile_operator(
ReformatAttribute extra_attribute) const {
auto graph = ComputingGraph::make();
graph->options().graph_opt_level = 0;
graph->options().graph_opt.weight_preprocess =
opr->owner_graph()->options().graph_opt.weight_preprocess;
graph->options().var_sanity_check_first_run = false;
VarNodeArray new_inps(opr->input().size());
size_t i = 0;
......@@ -274,8 +300,11 @@ float ProfilerImpl::profile_operator(
config.input_tensor_formats[i], extra_attribute);
}
dval->resize(aligned_shape);
auto aligned_var = opr::VolatileSharedDeviceTensor::make(*graph, dval);
new_inps[i] = aligned_var.node();
if (config.input_tensor_types[i] == TensorType::WEIGHT) {
new_inps[i] = opr::SharedDeviceTensor::make_const(*graph, dval).node();
} else {
new_inps[i] = opr::VolatileSharedDeviceTensor::make(*graph, dval).node();
}
}
for (; i < opr->input().size(); ++i) {
auto&& var = opr->input(i);
......@@ -291,7 +320,9 @@ float ProfilerImpl::profile_operator(
auto imm = opr::ImmutableTensor::make(*graph, *hval);
new_inps[i] = imm.node();
}
VarNode* y = mgb::gopt::intl::modify_opr_format(config.opr_format, new_inps, opr);
intl::OprFormatInfo opr_format_info;
opr_format_info.opr_format = config.opr_format;
VarNode* y = mgb::gopt::intl::modify_opr_format(opr_format_info, new_inps, opr);
static const ThinHashSet<Typeinfo*> multi_algo_oprs = {
opr::Convolution::typeinfo(),
opr::ConvBiasForward::typeinfo(),
......
......@@ -587,9 +587,9 @@ const ReformatManager& ReformatManager::instance() {
return inst;
}
TensorShape ReformatManager::make_aligned_tensor_shape(
TensorShape ReformatManager::try_make_tensor_shape(
const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
ReformatKey::Attribute extra_attribute) {
ReformatKey::Attribute extra_attribute, bool allow_aligned) {
using Dimension = megdnn::Dimension;
static constexpr uint32_t UNDETERMINED_EXTENT = Dimension::UNDETERMINED_EXTENT;
auto orig_shape = tensor_formats_to_named_tensor_shape(orig_formats);
......@@ -623,8 +623,17 @@ TensorShape ReformatManager::make_aligned_tensor_shape(
: (orig_shape[idx] / target_shape[i]).extent();
if (mul)
tshp[i] = oshp[idx] * factor;
else
else {
if (allow_aligned)
tshp[i] = divup(oshp[idx], factor);
else if (!(oshp[idx] % factor)) {
tshp[i] = oshp[idx] / factor;
} else {
return TensorShape{};
}
}
/// hack for nhwc auto padding
if (name == Dimension::Name::C) {
size_t channel_alignment = target_shape[i].stride();
size_t channels = tshp[i] * channel_alignment;
......@@ -641,6 +650,15 @@ TensorShape ReformatManager::make_aligned_tensor_shape(
return tshp;
}
TensorShape ReformatManager::make_aligned_tensor_shape(
const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
ReformatKey::Attribute extra_attribute) {
auto tshp = ReformatManager::try_make_tensor_shape(
var, orig_formats, target_formats, extra_attribute);
mgb_assert(tshp.ndim);
return tshp;
}
TensorShape ReformatManager::make_aligned_weight_shape(
const VarNode* var, TensorFormats orig_formats, TensorFormats target_formats,
TensorFormats extra_formats, ReformatKey::Attribute extra_attribute) {
......
......@@ -93,6 +93,8 @@ static inline TensorFormats opr_format_to_tensor_formats(
return TensorFormats::NCHWc4;
case OprFormat::NCHW8:
return TensorFormats::NCHWc8;
case OprFormat::NCHW44_DOT:
return TensorFormats::NCHWc4;
default:
mgb_throw(
AssertionError, "format(%s) is not supported",
......@@ -171,7 +173,6 @@ static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape(
static_cast<uint32_t>(format));
}
}
} // namespace gopt
} // namespace mgb
......
......@@ -135,6 +135,13 @@ public:
const VarNode* orig_var, const ReformatKey& key,
const AlignmentDesc& extra_alignment = {}) const;
/// return empty shape, if shape of origin varnode does not satisfy the alignment
/// requirement of the target tensor formats
static TensorShape try_make_tensor_shape(
const VarNode* var, TensorFormats orig_formats,
TensorFormats target_formats,
ReformatKey::Attribute extra_attribute = ReformatKey::Attribute::DEFAULT,
bool allow_aligned = true);
static TensorShape make_aligned_tensor_shape(
const VarNode* var, TensorFormats orig_formats,
TensorFormats target_formats,
......
此差异由.gitattributes 抑制。
......@@ -20,7 +20,7 @@
# 2. 编译megbrain_test,并运行所有全局图优化相关测试:
# ./megbrain_test --gtest_filter="*LayoutTransform*"
# 3. 用这个脚本把所有的cache文件打包在一起
# python3 embed_cache.py -o cache_data.h -r $(ls /path/to/cache/*.cache)
# python3 embed_cache.py -o cache_data.h -r -a $(ls /path/to/cache/*.cache)
# 4. 将步骤1中的 define 语句改回原样,这样 profile 过程就会使用 cache 下来的数据。
# 5. 最后可以重新构建一下 megbrain_test ,确保测试结果正确。
import os.path
......@@ -44,9 +44,10 @@ def _u32(data):
class CacheDataGenerator:
_cache_files = None
def __init__(self, cache_files, remove_plat_info = True):
def __init__(self, cache_files, remove_plat_info=True, append_cache=True):
self._cache_files = cache_files
self._remove_plat_info = remove_plat_info
self._append_cache = append_cache
def _get_hash(self):
return _u32(self._hash.digest()[:4])
......@@ -71,6 +72,7 @@ class CacheDataGenerator:
return ','.join(ret)
def gen_cache_data_header(self, fout, src_map):
if not self._append_cache:
fout.write('// generated embed_cache.py\n')
fout.write('#include <vector>\n')
fout.write('#include <stdint.h>\n')
......@@ -89,7 +91,11 @@ static const std::vector<uint8_t> {} = {{
assert ext == ".cache", "ext: {}, fname {}".format(ext, fname)
assert base not in fname2cache_data, "duplicated kernel: " + base
fname2cache_data[base] = self.gen_cache_data(fname)
with open(output, 'w') as fout:
if self._append_cache:
mode = 'a'
else:
mode = 'w'
with open(output, mode) as fout:
self.gen_cache_data_header(fout, fname2cache_data)
logger.info('done')
......@@ -107,7 +113,15 @@ if __name__ == '__main__':
default=True,
help="whether remove platform infomation in the cache (default: True)"
)
parser.add_argument(
"-a",
"--append-cache",
action='store_true',
default=True,
help="whether append the cache (default: True)"
)
parser.add_argument('cache', help='cache files to be embedded', nargs='+')
args = parser.parse_args()
cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info)
cache_generator = CacheDataGenerator(args.cache, args.remove_plat_info,
args.append_cache)
cache_generator.invoke(args.output)
......@@ -812,7 +812,9 @@ TEST(TestLayoutTransform, Resnet18_F16) {
.add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
.add_pass<ShuffleShuffleRemovePass>()
.add_pass(FuseNCHW4Int8Preprocess::make())
#if CUDA_VERSION >= 10020
.add_pass<FoldingConvBiasDimshufflePass>()
#endif
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply({{output}})
......@@ -1205,4 +1207,112 @@ TEST(TestLayoutTransform, MobileNetV2_NCHW44_DOT) {
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
#if MGB_CUDA
TEST(TestLayoutTransform, Concat) {
REQUIRE_GPU(1);
auto cn = CompNode::load("gpu0");
cn.activate();
REQUIRE_CUDA_COMPUTE_CAPABILITY(6, 1);
constexpr size_t N = 16, C = 3, H = 736, W = 1280;
HostTensorGenerator<dtype::Uint8> gen;
auto graph = ComputingGraph::make();
auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W}, cn));
auto data = opr::TypeCvt::make(h2d, dtype::Float32());
auto sub_128 = data + (-128);
auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f));
auto mkcvar = [&](const char* name, const TensorShape& shp, const DType& dtype) {
return opr::TypeCvt::make(
opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name),
dtype);
};
auto w = mkcvar("w", {2, 3, 3, 3}, dtype::QuantizedS8(1.f));
auto b = mkcvar("b", {1, 2, 1, 1}, dtype::QuantizedS32(1.f));
opr::ConvBias::Param param;
param.format = opr::ConvBias::Param::Format::NCHW;
param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
param.stride_h = param.stride_w = 2;
param.pad_h = param.pad_w = 1;
auto conv_1 = opr::ConvBias::make(
x, w, b, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
auto conv_1_cat = opr::Concat::make({conv_1, -conv_1}, 1);
auto w2 = mkcvar("w", {4, 4, 3, 3}, dtype::QuantizedS8(1.f));
auto b2 = mkcvar("b", {1, 4, 1, 1}, dtype::QuantizedS32(1.f));
auto conv_2 = opr::ConvBias::make(
conv_1_cat, w2, b2, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
auto conv_2_cat = opr::Concat::make({conv_2, -conv_2}, 1);
auto w3 = mkcvar("w", {16, 8, 3, 3}, dtype::QuantizedS8(1.f));
auto b3 = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
auto y = opr::ConvBias::make(
conv_2_cat, w3, b3, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
S strategy = S::PROFILE;
gopt::modify_opr_algo_strategy_inplace({y}, strategy);
using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
opr::Elemwise::typeinfo(), opr::TypeCvt::typeinfo(),
opr::Concat::typeinfo(),
};
SmallVector<TensorFormats> available_tensor_formats = {
TensorFormats::NCHW, TensorFormats::NCHWc4};
Attribute attribute = {
OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats), attribute);
ctx->add_opr_config(
opr::ConvBiasForward::typeinfo(),
{OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4});
#if MGB_WITH_CACHED_TEST
auto profiler = std::make_unique<ProfilerMock>(
static_cast<const uint8_t*>(TestLayoutTransform_Concat.data()),
TestLayoutTransform_Concat.size());
#else
auto profiler =
ProfilerBase::make_cached_profiler("TestLayoutTransform.Concat.cache");
#endif
std::unique_ptr<SolverBase> solver{
new DynamicProgrammingSolver(std::move(profiler))};
auto new_out_vars =
gopt::GraphOptimizer{}
.add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
.add_pass<ShuffleShuffleRemovePass>()
.add_pass(FuseNCHW4Int8Preprocess::make())
#if CUDA_VERSION >= 10020
.add_pass<FoldingConvBiasDimshufflePass>()
.add_pass<FoldingConvBiasTypecvtPass>()
#endif
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply(SymbolVarArray{y})
.endpoint_vars();
const auto& v = new_out_vars[0];
using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
std::vector<OutputSpecItem> outs;
for (auto&& i : new_out_vars) {
outs.emplace_back(OutputSpecItem{i, {}});
}
GraphProfiler gprof{graph.get()};
auto func = graph->compile(outs);
func->execute();
gprof.to_json_full(func.get())->writeto_fpath(output_file("conv_cat.json"));
SmallVector<cg::OperatorNodeBase*> oprs;
auto cb = [&oprs](cg::OperatorNodeBase* opr) {
if (opr->same_type<opr::Concat>()) {
oprs.push_back(opr);
}
};
cg::DepOprIter{cb}.add(v.node()->owner_opr());
ASSERT_EQ(oprs.size(), 4);
ASSERT_EQ(oprs[0]->output(0)->shape().ndim, 4);
ASSERT_EQ(oprs[2]->output(0)->shape().ndim, 5);
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册