Commit fca19535 authored by Megvii Engine Team

feat(gopt): add nhwc fuse conv typecvt optpass

GitOrigin-RevId: adc230120323835d27786e8908e0f427f3e2f2df
Parent 2fc73585
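This pass folds an NHWC quantized ConvBias followed by a TypeCvt (qs8 to q4, q4 to qs8, or qs8 to f32) into a single ConvBias whose output dtype is the TypeCvt target, so the standalone conversion kernel disappears. A minimal usage sketch follows; the pass-registration chain mirrors the test added by this commit, while the wrapper function, the variable `y`, and the trailing `endpoint_vars()` call are illustrative assumptions rather than part of this diff.

// Sketch only: run the new pass on an endpoint SymbolVar `y` produced by
// TypeCvt(ConvBias(NHWC, quantized)). FoldingConvBiasTypecvtPass is only
// compiled when CUDA_VERSION >= 10020, as in the sources below.
#include "megbrain/gopt/framework.h"
#include "megbrain/gopt/inference.h"

using namespace mgb;

SymbolVar apply_conv_typecvt_folding(SymbolVar y) {
    return gopt::GraphOptimizer{}
            .add_pass<gopt::FoldingConvBiasDimshufflePass>()
            .add_pass<gopt::FoldingConvBiasTypecvtPass>()
            .add_pass<gopt::ParamFusePass>()
            .apply(SymbolVarArray{y})
            .endpoint_vars()[0];
}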
......@@ -121,7 +121,10 @@ bool ConvBiasForwardImpl::AlgoFallbackNCHWQS8::is_available(
auto&& param = args.opr->param();
bool is_format_ok = param.format == param::ConvBias::Format::NCHW;
bool is_version_ok = CUDNN_VERSION >= 7500;
bool is_dtype_ok = args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8;
bool is_dtype_ok =
(args.src_layout->dtype.enumv() == DTypeEnum::QuantizedS8 &&
(args.dst_layout->dtype.enumv() != DTypeEnum::QuantizedS4 &&
args.dst_layout->dtype.enumv() != DTypeEnum::Quantized4Asymm));
bool is_bias_ok =
args.bias_layout->ndim == 0 ||
(args.bias_layout->ndim == 4 && args.bias_layout->shape[0] == 1 &&
......
......@@ -31,6 +31,11 @@ bool ConvBiasForwardImpl::AlgoCUDNNConv::is_available(
}
}
if (args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4 ||
args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm) {
return false;
}
// FIXME: cudnn cannot handle the case when the initial value of dst tensor
// contains nan and beta is zero, because the result of 0.f * nan is still
// nan
......
......@@ -24,6 +24,11 @@ bool ConvBiasForwardImpl::AlgoMatmul8x8x32::is_available(
if (!is_compute_capability_required(6, 1))
return false;
if (args.dst_layout->dtype.enumv() == DTypeEnum::Quantized4Asymm ||
args.dst_layout->dtype.enumv() == DTypeEnum::QuantizedS4) {
return false;
}
auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
......
/**
* \file src/gopt/impl/folding_conv_typecvt.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/gopt/inference.h"
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/opr/utility.h"
#include "megbrain/serialization/opr_shallow_copy.h"
#include "megdnn/opr_param_defs.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/utils/hash_ct.h"
#include "midout.h"
#include "megbrain/gopt/reformat_manager.h"
#if CUDA_VERSION >= 10020
MIDOUT_DECL(megbrain_folding_conv_typecvt)
#define MIDOUT_B(tag) \
MIDOUT_BEGIN(megbrain_folding_conv_typecvt, midout_iv(MGB_HASH_STR(tag))) {
#define MIDOUT_E \
} \
MIDOUT_END();
using namespace mgb;
using namespace gopt;
using ReformatKey = ReformatManager::ReformatKey;
/* ==================== FoldingConvBiasTypecvtPass ================= */
const char* FoldingConvBiasTypecvtPass::name() const {
return mgb_cstr_log("folding conv bias typecvt pass");
}
void FoldingConvBiasTypecvtPass::apply(OptState& opt) const {
MIDOUT_B("FoldingConvBiasTypecvtPass::apply");
using DepType = cg::OperatorNodeProp::DepType;
ThinHashMap<OperatorNodeBase*,
SmallVector<std::pair<OperatorNodeBase*, DepType>>>
readers;
static const ThinHashSet<Typeinfo*> opr_type_list = {
opr::TypeCvt::typeinfo(), opr::ConvBias::typeinfo()};
opt.graph().iter([&readers](OperatorNodeBase* opr) {
for (auto&& i : opr->node_prop().dep_map()) {
if (opr_type_list.count(i.first->owner_opr()->dyn_typeinfo())) {
readers[i.first->owner_opr()].emplace_back(opr, i.second);
}
}
});
auto rewriter = opt.graph().make_rewriter();
auto try_conv_typecvt = [&rewriter, &readers](OperatorNodeBase* opr) {
ThinHashSet<OperatorNodeBase*> opr_set;
ThinHashSet<OperatorNodeBase*> reader_set;
// check typecvt
auto typecvt = try_cast_as_op<opr::TypeCvt>(opr);
if (typecvt == nullptr)
return false;
auto inp_dtype_typecvt = typecvt->input(0)->dtype(),
out_dtype_typecvt = typecvt->output(0)->dtype();
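// foldable conversions: qs8 -> f32, qs8 -> q4 (signed or asymm), q4 -> qs8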
bool is_s82f32 = inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype_typecvt.enumv() == DTypeEnum::Float32;
bool is_s82s4 =
inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8 &&
(out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
out_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm);
bool is_s42s8 =
(inp_dtype_typecvt.enumv() == DTypeEnum::QuantizedS4 ||
inp_dtype_typecvt.enumv() == DTypeEnum::Quantized4Asymm) &&
out_dtype_typecvt.enumv() == DTypeEnum::QuantizedS8;
if (!(is_s82f32 || is_s82s4 || is_s42s8))
return false;
opr_set.insert(opr);
// check conv bias
auto conv_bias =
try_cast_as_op<opr::ConvBias>(typecvt->input(0)->owner_opr());
if (conv_bias == nullptr)
return false;
auto inp_dtype_conv = conv_bias->input(0)->dtype(),
out_dtype_conv = conv_bias->output(0)->dtype();
bool is_s8nhwc = inp_dtype_conv.enumv() == DTypeEnum::QuantizedS8 &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NHWC;
bool is_s4nhwc =
(inp_dtype_conv.enumv() == DTypeEnum::QuantizedS4 ||
inp_dtype_conv.enumv() == DTypeEnum::Quantized4Asymm) &&
out_dtype_conv.enumv() == inp_dtype_conv.enumv() &&
conv_bias->param().format ==
megdnn::param::ConvBias::Format::NHWC;
if (!(is_s8nhwc || is_s4nhwc))
return false;
if (conv_bias->input().size() != 3)
return false;
opr_set.insert(conv_bias);
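// the conv_bias output may only be consumed by oprs inside the matched
// pattern; otherwise folding would change the dtype seen by other readers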
for (auto&& i : readers[conv_bias]) {
if (i.second & DepType::DEV_VALUE) {
reader_set.insert(i.first);
}
}
for (auto reader : reader_set) {
if (opr_set.count(reader) <= 0) {
return false;
}
}
auto src = rewriter.get_var(conv_bias->input(0)),
filter = rewriter.get_var(conv_bias->input(1)),
bias = rewriter.get_var(conv_bias->input(2));
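// folding a qs8 -> f32 typecvt requires a float bias; for quantized
// output dtypes the original quantized bias is kept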
auto new_bias =
(out_dtype_typecvt.enumv() == DTypeEnum::Float32)
? opr::TypeCvt::make(bias, dtype::Float32()).node()
: bias;
auto new_param = conv_bias->param();
new_param.format = megdnn::param::ConvBias::Format::NHWC;
auto conv_bias_typecvt = opr::ConvBias::make(
src, filter, new_bias, new_param, conv_bias->execution_policy(),
OperatorNodeConfig{out_dtype_typecvt});
rewriter.replace_var(opr->output(0), conv_bias_typecvt.node(),
mgb_cstr_log("replace conv_bias(NHWC) + typecvt "
"to conv_bias(NHWC)"));
return true;
};
auto on_opr = [&try_conv_typecvt, &rewriter](OperatorNodeBase* opr) {
if (!try_conv_typecvt(opr)) {
rewriter.auto_replace_outputs(opr);
}
};
opt.graph().iter(on_opr);
rewriter.apply_inplace();
MIDOUT_E
}
#endif
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -835,6 +835,7 @@ const GraphOptimizer& GraphOptimizer::add_passes_for_graph_tuning_options(
add_pass<FuseWarpPerspectiveDimshufflePass>();
#if CUDA_VERSION >= 10020
add_pass<FoldingConvBiasDimshufflePass>();
add_pass<FoldingConvBiasTypecvtPass>();
#endif
});
#undef cb
......
......@@ -57,7 +57,10 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
Attribute attribute = {
base_opr_format, base_tensor_format, Target::CUDA,
LayoutTransformContext::ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
attribute);
......@@ -67,8 +70,9 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
.add_opr_config(opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
.add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
.add_opr_config(
opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4, OprFormat::NHWC})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
......
......@@ -512,7 +512,7 @@ struct ConvTensorFormatsDispatcherImpl<opr::ConvolutionBackwardData,
const auto& conv = opr->cast_final_safe<Opr>();
OprTensorFormatsConfiguration config;
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NCHW4;
config.opr_format = OprFormat::NHWC;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
available &=
......
......@@ -481,6 +481,12 @@ namespace gopt {
const char* name() const override;
void apply(OptState& opt) const override;
};
class FoldingConvBiasTypecvtPass final : public Pass {
public:
const char* name() const override;
void apply(OptState& opt) const override;
};
#endif
/*!
......
......@@ -585,6 +585,7 @@ TEST(TestLayoutTransform, DetectionHead) {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
using Target = LayoutTransformContext::Target;
OprList opr_list = {
opr::ConvBiasForward::typeinfo(),
......@@ -600,8 +601,8 @@ TEST(TestLayoutTransform, DetectionHead) {
TensorFormats::NCHW, TensorFormats::NHWC,
TensorFormats::NCHWc4, TensorFormats::NCHWc32,
TensorFormats::NCHWc64, TensorFormats::CHWNc4};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW,
Target::UNSPEC};
Attribute attribute = {OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
ReformatAttribute::AUTO_PADDING_NHWC};
auto ctx = std::make_unique<LayoutTransformContext>(
std::move(opr_list), std::move(available_tensor_formats),
attribute);
......@@ -611,8 +612,9 @@ TEST(TestLayoutTransform, DetectionHead) {
OprFormat::NCHW32, OprFormat::NCHW64, OprFormat::CHWN4})
.add_opr_config(opr::ConvolutionForward::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
.add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NCHW4})
.add_opr_config(
opr::ConvolutionBackwardData::typeinfo(),
{OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
.add_opr_config(
opr::PoolingForward::typeinfo(),
{OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
......@@ -630,6 +632,7 @@ TEST(TestLayoutTransform, DetectionHead) {
.add_pass<ShuffleShuffleRemovePass>()
.add_pass(FuseNCHW4Int8Preprocess::make())
.add_pass<FoldingConvBiasDimshufflePass>()
.add_pass<FoldingConvBiasTypecvtPass>()
.add_pass<ParamFusePass>()
.add_pass<ParamMergePass>()
.apply(SymbolVarArray{y})
......@@ -656,7 +659,8 @@ TEST(TestLayoutTransform, DetectionHead) {
/// check first conv format
const auto& first_conv = find_opr<opr::ConvBiasForward>(v);
const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4_NHWC);
ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
}
#endif
#endif
......