Commit f2db7b0d authored by Megvii Engine Team

feat(mgb/gopt): global layout transform support cuda fp16

GitOrigin-RevId: 1449c54215d053d2bd22c6f6fc5235c1a9fb560f
Parent ca7cec7a
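The hunks below extend the NHWC tensor-format dispatchers so that, in addition to the 4-bit/8-bit quantized dtypes, Float16 inputs and outputs are accepted, register NCHW/NHWC configs for Pooling and WarpPerspective in the CUDA layout-transform context, and add an fp16 ResNet-18 test. A minimal standalone sketch of the dtype-gating pattern used throughout the dispatchers is shown below; it assumes `DNN_FLOAT16_SELECT(x, y)` expands to `x` when fp16 support is compiled in and to `y` otherwise, and it uses a local mock macro and dtype enum rather than the real MegDNN headers.

```cpp
// Standalone sketch, not MegEngine source: the macro and dtype enum here are
// local stand-ins so the example compiles on its own. It mirrors the
// availability check the NHWC dispatchers use after this commit: an input is
// accepted if it is Float16 (gated by DNN_FLOAT16_SELECT) or a 4-bit
// quantized dtype.
#include <cstdio>

#define SKETCH_WITH_FLOAT16 1  // flip to 0 to emulate a build without fp16
#if SKETCH_WITH_FLOAT16
#define DNN_FLOAT16_SELECT(_x, _y) (_x)
#else
#define DNN_FLOAT16_SELECT(_x, _y) (_y)
#endif

enum class DTypeEnum { Float16, Float32, Quantized4Asymm, QuantizedS4, QuantizedS8 };

// Same shape as the check added in the NHWC single-in/out dispatcher.
bool nhwc_feature_supported(DTypeEnum dt) {
    bool f16_config = DNN_FLOAT16_SELECT(dt == DTypeEnum::Float16, true);
    bool i4_config =
            dt == DTypeEnum::Quantized4Asymm || dt == DTypeEnum::QuantizedS4;
    return f16_config || i4_config;
}

int main() {
    std::printf("Float16     -> %d\n", nhwc_feature_supported(DTypeEnum::Float16));
    std::printf("QuantizedS4 -> %d\n", nhwc_feature_supported(DTypeEnum::QuantizedS4));
    std::printf("Float32     -> %d\n", nhwc_feature_supported(DTypeEnum::Float32));
    return 0;
}
```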
@@ -78,13 +78,13 @@ std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
OprFormatConfigID::NHWC})
.add_opr_config(
opr::PoolingForward::typeinfo(),
- {OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW32,
- OprFormatConfigID::NHWC, OprFormatConfigID::NCHW64,
- OprFormatConfigID::CHWN4})
+ {OprFormatConfigID::NCHW, OprFormatConfigID::NCHW4,
+ OprFormatConfigID::NCHW32, OprFormatConfigID::NHWC,
+ OprFormatConfigID::NCHW64, OprFormatConfigID::CHWN4})
.add_opr_config(
opr::WarpPerspectiveForward::typeinfo(),
- {OprFormatConfigID::NHWC, OprFormatConfigID::NCHW4,
- OprFormatConfigID::NCHW64});
+ {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC,
+ OprFormatConfigID::NCHW4, OprFormatConfigID::NCHW64});
return ctx;
}
......
@@ -191,8 +191,11 @@ struct OprSingleInOutTensorFormatsDispatcherImpl<OprFormatConfigID::NHWC> {
config.typeinfo = opr->dyn_typeinfo();
config.opr_format = OprFormat::NHWC;
config.config_id = OprFormatConfigID::NHWC;
- bool available = opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
+ bool f16_config = DNN_FLOAT16_SELECT(
+ (opr->input(0)->dtype().enumv() == DTypeEnum::Float16), true);
+ bool i4_config = opr->input(0)->dtype().enumv() == DTypeEnum::Quantized4Asymm ||
opr->input(0)->dtype().enumv() == DTypeEnum::QuantizedS4;
+ bool available = f16_config || i4_config;
config.input_dtypes = {opr->input(0)->dtype().enumv()};
config.input_tensor_types = {TensorType::FEATURE};
available &= opr->output(0)->dtype().enumv() == opr->input(0)->dtype().enumv();
@@ -275,16 +278,22 @@ struct ConvTensorFormatsDispatcherImpl<Opr, OprFormatConfigID::NHWC> {
config.opr_format = OprFormat::NHWC;
config.config_id = OprFormatConfigID::NHWC;
auto check_dtype = [](const DType& dt) {
+ bool f16_config =
+ DNN_FLOAT16_SELECT((dt.enumv() == DTypeEnum::Float16), true);
bool i4_config = dt.enumv() == DTypeEnum::Quantized4Asymm ||
dt.enumv() == DTypeEnum::QuantizedS4;
bool i8_config = dt.enumv() == DTypeEnum::QuantizedS8;
- return i4_config || i8_config;
+ return f16_config || i4_config || i8_config;
};
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
- if (i == 2)
- available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32;
- else {
+ if (i == 2) {
+ available &=
+ opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS32 ||
+ DNN_FLOAT16_SELECT(
+ opr->input(i)->dtype().enumv() == DTypeEnum::Float16,
+ true);
+ } else {
available &= check_dtype(opr->input(i)->dtype());
}
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
@@ -866,12 +875,18 @@ struct ConvTensorFormatsDispatcherImpl<
config.config_id = OprFormatConfigID::NHWC;
bool available = true;
for (size_t i = 0; i < opr->input().size(); ++i) {
- available &= opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8;
+ available &=
+ opr->input(i)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+ DNN_FLOAT16_SELECT(
+ opr->input(i)->dtype().enumv() == DTypeEnum::Float16, true);
config.input_dtypes.emplace_back(opr->input(i)->dtype().enumv());
TensorType tensor_type = i == 0 ? TensorType::WEIGHT : TensorType::FEATURE;
config.input_tensor_types.emplace_back(tensor_type);
}
- available &= opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8;
+ available &=
+ opr->output(0)->dtype().enumv() == DTypeEnum::QuantizedS8 ||
+ DNN_FLOAT16_SELECT(
+ opr->output(0)->dtype().enumv() == DTypeEnum::Float16, true);
config.output_dtypes.emplace_back(opr->output(0)->dtype().enumv());
available &= conv.param().sparse == opr::ConvBias::Param::Sparse::DENSE;
config.input_tensor_formats = {
@@ -934,6 +949,7 @@ StaticData::StaticData() {
OPR_TENSOR_FORMATS_CONFIG_REG(ConvBias, NCHW44_DOT_HYBRID);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NHWC);
+ OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW4);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW44);
OPR_TENSOR_FORMATS_CONFIG_REG(ConvolutionForward, NCHW88);
......
@@ -29,6 +29,8 @@
#include "./cache_data.h"
#endif
#include "megbrain/plugin/opr_io_dump.h"
using namespace mgb;
using namespace gopt;
using namespace serialization;
@@ -748,6 +750,95 @@ TEST(TestLayoutTransform, CanonicalizeLayoutTransform) {
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
+ #if MGB_CUDA
+ TEST(TestLayoutTransform, Resnet18_F16) {
+ REQUIRE_GPU(1);
+ auto cn = CompNode::load("gpu0");
+ auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
+ auto sm_ver = prop.major * 10 + prop.minor;
+ if (sm_ver < 70) {
printf("This testcast ignored due to insufficient cuda cap(got: %d, "
"expected: %d)\n",
sm_ver, 70);
+ return;
+ }
+ Network network(cn);
+ auto output = make_resnet18(network, 16);
+ using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
+ S strategy = S::PROFILE;
+ gopt::modify_opr_algo_strategy_inplace({{output}}, strategy);
+ HostTensorND t1;
+ auto func1 = network.graph->compile({make_callback_copy(output, t1)});
+ func1->execute();
+ using OprFormatConfigID = LayoutTransformContext::OprFormatConfigID;
+ using OprList = LayoutTransformContext::OprList;
+ using Attribute = LayoutTransformContext::Attribute;
+ using Target = LayoutTransformContext::Target;
+ using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
+ OprList opr_list = {
+ opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
+ opr::Elemwise::typeinfo(), opr::TypeCvt::typeinfo(),
+ opr::PoolingForward::typeinfo(), opr::WarpPerspectiveForward::typeinfo(),
+ };
+ SmallVector<TensorFormats> available_tensor_formats = {
+ TensorFormats::NCHW, TensorFormats::NHWC};
+ Attribute attribute = {
+ OprFormatConfigID::NCHW, TensorFormats::NCHW, Target::UNSPEC,
+ ReformatAttribute::AUTO_PADDING_NHWC};
+ auto ctx = std::make_unique<LayoutTransformContext>(
+ std::move(opr_list), std::move(available_tensor_formats), attribute);
+ ctx->add_opr_config(
+ opr::ConvBiasForward::typeinfo(),
+ {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC})
+ .add_opr_config(
+ opr::PoolingForward::typeinfo(),
+ {OprFormatConfigID::NCHW, OprFormatConfigID::NHWC});
+ #if MGB_WITH_CACHED_TEST
+ auto profiler = std::make_unique<ProfilerMock>(
+ static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_F16.data()),
+ TestLayoutTransform_Resnet18_F16.size());
+ #else
+ auto profiler = ProfilerBase::make_cached_profiler(
+ "TestLayoutTransform.Resnet18_F16.cache");
+ #endif
+ std::unique_ptr<SolverBase> solver{
+ new DynamicProgrammingSolver(std::move(profiler))};
+ auto new_output =
+ gopt::GraphOptimizer{}
+ .add_pass(ConvertF32ToF16Pass::make(false))
+ .add_pass<FuseConvBiasNonlinPass>()
+ .add_pass<FuseConvBiasZPass>()
+ .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
+ .add_pass<ShuffleShuffleRemovePass>()
+ .add_pass(FuseNCHW4Int8Preprocess::make())
+ .add_pass<FoldingConvBiasDimshufflePass>()
+ .add_pass<ParamFusePass>()
+ .add_pass<ParamMergePass>()
+ .apply({{output}})
+ .endpoint_vars();
+ auto new_out_var = new_output[0];
+ /// check global layout transform pass
+ auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
+ ASSERT_EQ(nr_dimshuffle, 4u);
+ /// check pass fuse conv bias with z
+ auto nr_elemwise = find_opr_num<opr::Elemwise>(new_out_var);
+ ASSERT_EQ(nr_elemwise, 4u);
+ /// 21 convolutions, 21 weights and 21 biases, 42 parameters in total
+ const auto& param_merge = find_opr<opr::MultipleDeviceTensorHolder>(new_out_var);
+ ASSERT_EQ(param_merge.output().size(), 42u);
+ GraphProfiler gprof{network.graph.get()};
+ HostTensorND t2;
+ auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
+ func2->execute();
+ gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_f16.json"));
+ MGB_ASSERT_TENSOR_NEAR(t1, t2, 1e-3);
+ }
+ #endif
TEST(TestLayoutTransform, Resnet18_F32) {
auto cn = CompNode::load("cpu0");
@@ -1115,4 +1206,5 @@ TEST(TestLayoutTransform, MobileNetV2_NCHW44_DOT) {
/// check correct
MGB_ASSERT_TENSOR_EQ(t1, t2);
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
@@ -38,8 +38,13 @@ SymbolVar Network::add_conv(
param.nonlineMode = opr::ConvBias::Param::NonlineMode::IDENTITY;
}
- auto conv = opr::ConvBias::make(
- f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
+ SymbolVar conv;
+ if (out_dtype.category() == DTypeCategory::QUANTIZED) {
+ conv = opr::ConvBias::make(
+ f, weight, bias, param, {}, OperatorNodeConfig{out_dtype});
+ } else {
+ conv = opr::ConvBias::make(f, weight, bias, param, {});
+ }
weight_idx++;
bias_idx++;
return conv;
......