/**
 * \file src/gopt/test/layout_transform_pass.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "megbrain/gopt/layout_transform_pass.h"
#include "./network.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/gopt/inference.h"
#include "megbrain/gopt/layout_transform_context.h"
#include "megbrain/gopt/profiler.h"
#include "megbrain/gopt/solver.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"
#include "megbrain/plugin/profiler.h"
#include "megbrain/serialization/serializer.h"

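// When MGB_WITH_CACHED_TEST is enabled, the tests feed the layout solver with
// profiling results embedded in cache_data.h instead of profiling kernels on
// the device, so the chosen layouts do not depend on the machine running CI.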
#define MGB_WITH_CACHED_TEST 1

#if MGB_WITH_CACHED_TEST
#include "./cache_data.h"
#endif

using namespace mgb;
using namespace gopt;
using namespace serialization;

namespace {
//! find the first operator of a specific type; raise an exception if not found
template <typename T>
T& find_opr(SymbolVar endpoint) {
    T* found = nullptr;
    auto cb = [&found](cg::OperatorNodeBase* opr) {
        if (!found && opr->same_type<T>()) {
            found = &opr->cast_final_safe<T>();
        }
    };
    cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
    mgb_assert(found, "not found opr from %s", endpoint.node()->name().c_str());
    return *found;
}

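//! count the operators of a specific type reachable from the given endpoint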
template <typename T>
size_t find_opr_num(SymbolVar endpoint) {
    size_t opr_num = 0;
    auto cb = [&opr_num](cg::OperatorNodeBase* opr) {
        if (opr->same_type<T>()) {
            opr_num++;
        }
    };
    cg::DepOprIter{cb}.add(endpoint.node()->owner_opr());
    return opr_num;
}

using OprFormat = Problem::OprFormat;
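//! map a TensorFormats value to the OprFormat tag used in ProfilerCache keys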
OprFormat tensor_formats_to_opr_format(TensorFormats tensor_format) {
    switch (tensor_format) {
        case TensorFormats::NCHW:
            return OprFormat::NCHW;
        case TensorFormats::NCHWc4:
            return OprFormat::NCHW4;
        case TensorFormats::NCHWc8:
            return OprFormat::NCHW8;
        case TensorFormats::NCHWc32:
            return OprFormat::NCHW32;
        case TensorFormats::NCHWc64:
            return OprFormat::NCHW64;
        case TensorFormats::NHWC:
            return OprFormat::NHWC;
        case TensorFormats::CHWNc4:
            return OprFormat::CHWN4;
        default:
            mgb_throw(
                    MegBrainError, "tensor format(%u) is not supported",
                    static_cast<uint32_t>(tensor_format));
    }
}

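//! ProfilerImpl that replays measurements from a serialized ProfilerCache blob;
//! a cache miss is treated as a fatal error, so no on-device profiling happens.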
class ProfilerMock : public ProfilerImpl {
public:
    ProfilerMock(const uint8_t* bin, size_t size) {
        mgb_assert(bin != nullptr);
        ProfilerCache::inst().set_impl(
                std::make_unique<InFilePersistentCache>(bin, size));
        // disable saving platform information to keep CI stable.
        ProfilerCache::inst().enable_device_info(false);
    }
    ~ProfilerMock() {
        // reset the cache implementation to an in-memory cache
        ProfilerCache::inst().set_impl(std::make_unique<InMemoryPersistentCache>());
    }

private:
    float profile_operator(
            const OperatorNodeBase* opr, TensorFormats base_format,
            TensorFormats tensor_format,
            ReformatAttribute extra_attribute =
                    ReformatAttribute::DEFAULT) const override {
        ProfilerCache::Key key{
                opr, tensor_formats_to_opr_format(tensor_format), extra_attribute};
        auto ret = ProfilerCache::inst().get(key);
        if (ret.valid())
            return ret.val();
        mgb_assert(false);
    }
    float profile_operator(
            const OperatorNodeBase* opr,
            const OprTensorFormatsConfiguration& base_config,
            const OprTensorFormatsConfiguration& config,
            ReformatAttribute extra_attribute =
                    ReformatAttribute::DEFAULT) const override {
        ProfilerCache::Key key{opr, config.opr_format, extra_attribute};
        std::string tmp;
        tmp.reserve(key.blob().size);
        auto ret = ProfilerCache::inst().get(key);
        if (ret.valid())
            return ret.val();
        mgb_assert(false);
    }
    float profile_var_node(
            const VarNode* var, TensorFormats base_format,
            const ReformatKey& key) const override {
        ProfilerCache::Key pf_key{var, key};
        auto ret = ProfilerCache::inst().get(pf_key);
        if (ret.valid())
            return ret.val();
        mgb_assert(false);
    }
};
}  // namespace

#if MGB_CUDA
#if CUDA_VERSION >= 10020
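//! apply the global layout transform to a QuantizedS8 ResNet-18 and check that
//! the optimized graph matches the reference output and picks the expected formats.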
TEST(TestLayoutTransform, Resnet18_QS8) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
               "expected: %d)\n",
               sm_ver, 75);
        return;
    }
    Network network(cn);
    /// use a small batch size to reduce test time
    auto output = make_resnet18(network, 16, dtype::QuantizedS8{1.f});
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({{output}}, strategy);

    HostTensorND t1;
    auto func1 = network.graph->compile({make_callback_copy(output, t1)});
    func1->execute();

    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using Target = LayoutTransformContext::Target;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    using Attribute = LayoutTransformContext::Attribute;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),        opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),  opr::WarpPerspectiveForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW, TensorFormats::NHWC, TensorFormats::NCHWc4,
            TensorFormats::NCHWc32, TensorFormats::CHWNc4};
    Attribute attribute = {
            OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
            ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
                     OprFormat::CHWN4});
#if MGB_WITH_CACHED_TEST
    auto profiler = std::make_unique<ProfilerMock>(
            static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS8.data()),
            TestLayoutTransform_Resnet18_QS8.size());
#else
    auto profiler = ProfilerBase::make_cached_profiler(
            "TestLayoutTransform.Resnet18_QS8.cache");
#endif
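    /// the dynamic programming solver assigns a layout to every operator based on
    /// the profiled costs; LayoutTransformPass then rewrites the graph accordingly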
    std::unique_ptr<SolverBase> solver{
            new DynamicProgrammingSolver(std::move(profiler))};
    auto new_output =
            gopt::GraphOptimizer{}
                    .add_pass<FuseConvBiasNonlinPass>()
                    .add_pass<FuseConvBiasZPass>()
                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                    .add_pass<ShuffleShuffleRemovePass>()
                    .add_pass(FuseNCHW4Int8Preprocess::make())
                    .add_pass<FoldingConvBiasDimshufflePass>()
                    .add_pass<ParamFusePass>()
                    .add_pass<ParamMergePass>()
                    .apply({{output}})
                    .endpoint_vars();
    auto new_out_var = new_output[0];
    /// check global layout transform pass
    auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
    ASSERT_EQ(nr_dimshuffle, 3u);
    /// check the pass that fuses conv bias with z
    auto nr_elemwise_mult_type = find_opr_num<opr::ElemwiseMultiType>(new_out_var);
    ASSERT_EQ(nr_elemwise_mult_type, 4u);
    /// 21 convolutions: 21 weights and 21 biases, 42 parameters in total
    const auto& param_merge = find_opr<opr::MultipleDeviceTensorHolder>(new_out_var);
    ASSERT_EQ(param_merge.output().size(), 42u);
    /// check first conv format
    const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW4);

    GraphProfiler gprof{network.graph.get()};
    HostTensorND t2;
    auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
    func2->execute();
    gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_qs8.json"));
    /// check correctness
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}

TEST(TestLayoutTransform, Resnet18_QS4) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
               "expected: %d)\n",
               sm_ver, 75);
        return;
    }
    Network network(cn);
    auto output = make_resnet18(network, 16, dtype::QuantizedS4{1.f});
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({{output}}, strategy);

    HostTensorND t1;
    auto func1 = network.graph->compile({make_callback_copy(output, t1)});
    func1->execute();

    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using Attribute = LayoutTransformContext::Attribute;
    using Target = LayoutTransformContext::Target;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),        opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),  opr::WarpPerspectiveForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW,    TensorFormats::NHWC,    TensorFormats::NCHWc4,
            TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {
            OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
            ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC,
                OprFormat::NCHW64})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64,
                     OprFormat::NHWC, OprFormat::CHWN4});
#if MGB_WITH_CACHED_TEST
    auto profiler = std::make_unique<ProfilerMock>(
            static_cast<const uint8_t*>(TestLayoutTransform_Resnet18_QS4.data()),
            TestLayoutTransform_Resnet18_QS4.size());
#else
    auto profiler = ProfilerBase::make_cached_profiler(
            "TestLayoutTransform.Resnet18_QS4.cache");
#endif
    std::unique_ptr<SolverBase> solver{
            new DynamicProgrammingSolver(std::move(profiler))};
    auto new_output =
            gopt::GraphOptimizer{}
                    .add_pass<FuseConvBiasNonlinPass>()
                    .add_pass<FuseConvBiasZPass>()
                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                    .add_pass<ShuffleShuffleRemovePass>()
                    .add_pass(FuseNCHW4Int8Preprocess::make())
                    .add_pass<FoldingConvBiasDimshufflePass>()
                    .add_pass<ParamFusePass>()
                    .add_pass<ParamMergePass>()
                    .apply({{output}})
                    .endpoint_vars();
    auto new_out_var = new_output[0];
    /// check global layout transform pass
    auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(new_out_var);
    ASSERT_EQ(nr_dimshuffle, 3u);
    /// check the pass that fuses conv bias with z
    auto nr_elemwise_mult_type = find_opr_num<opr::ElemwiseMultiType>(new_out_var);
    ASSERT_EQ(nr_elemwise_mult_type, 4u);
    /// 21 convolutions: 21 weights and 21 biases, 42 parameters in total
    const auto& param_merge = find_opr<opr::MultipleDeviceTensorHolder>(new_out_var);
    ASSERT_EQ(param_merge.output().size(), 42u);
    /// check first conv format
    const auto& first_conv = find_opr<opr::ConvBiasForward>(new_out_var);
    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);

    GraphProfiler gprof{network.graph.get()};
    HostTensorND t2;
    auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
    func2->execute();
    gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_qs4.json"));
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}

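//! reference check of the whole-graph NCHW64 optimize_for_inference path against
//! the unoptimized network.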
TEST(TestLayoutTransform, Resnet18_NCHW64) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
               "expected: %d)\n",
               sm_ver, 75);
        return;
    }
    Network network(cn);
    auto output = make_resnet18(network, 64, dtype::QuantizedS4{1.f});
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({{output}}, strategy);

    HostTensorND t1;
    auto func1 = network.graph->compile({make_callback_copy(output, t1)});
    func1->execute();

    SymbolVar new_out_var;
    auto options = gopt::OptimizeForInferenceOptions{};
    options.enable_nchw64();
    unpack_vector(gopt::optimize_for_inference({output}, options), new_out_var);

    GraphProfiler gprof{network.graph.get()};
    HostTensorND t2;
    auto func2 = network.graph->compile({make_callback_copy(new_out_var, t2)});
    func2->execute();
    gprof.to_json_full(func2.get())->writeto_fpath(output_file("resnet18_nchw64.json"));
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}

TEST(TestLayoutTransform, Detection_QS8) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
               "expected: %d)\n",
               sm_ver, 75);
        return;
    }
    Network network(cn);
    auto outputs = make_det(network, 16, dtype::QuantizedS8{1.f});
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({outputs}, strategy);

    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using Attribute = LayoutTransformContext::Attribute;
    using Target = LayoutTransformContext::Target;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),        opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),  opr::WarpPerspectiveForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW,    TensorFormats::NHWC,    TensorFormats::NCHWc4,
            TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {
            OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
            ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC,
                OprFormat::NCHW64})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64,
                     OprFormat::NHWC, OprFormat::CHWN4});
#if MGB_WITH_CACHED_TEST
    auto profiler = std::make_unique<ProfilerMock>(
            static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS8.data()),
            TestLayoutTransform_Detection_QS8.size());
#else
    auto profiler = ProfilerBase::make_cached_profiler(
            "TestLayoutTransform.Detection_QS8.cache");
#endif
    std::unique_ptr<SolverBase> solver{
            new DynamicProgrammingSolver(std::move(profiler))};
    auto new_outputs =
            gopt::GraphOptimizer{}
                    .add_pass<FuseConvBiasNonlinPass>()
                    .add_pass<FuseConvBiasZPass>()
                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                    .add_pass<ShuffleShuffleRemovePass>()
                    .add_pass(FuseNCHW4Int8Preprocess::make())
                    .add_pass<FoldingConvBiasDimshufflePass>()
                    .add_pass<ParamFusePass>()
                    .add_pass<ParamMergePass>()
                    .apply({{outputs}})
                    .endpoint_vars();

    GraphProfiler gprof{network.graph.get()};
    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
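    /// compile every output with an empty callback: this test only checks that the
    /// transformed detection graph builds and runs, and dumps the profile to json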
    std::vector<OutputSpecItem> output_spec;
    for (const auto& i : new_outputs) {
        output_spec.emplace_back(OutputSpecItem{i, {}});
    }
    auto func = network.graph->compile(output_spec);
    func->execute();
    gprof.to_json_full(func.get())->writeto_fpath(output_file("det_qs8.json"));
}

TEST(TestLayoutTransform, Detection_QS4) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    auto&& prop = CompNodeEnv::from_comp_node(cn).cuda_env().device_prop;
    auto sm_ver = prop.major * 10 + prop.minor;
    if (sm_ver < 75) {
        printf("This testcast ignored due to insufficient cuda cap(got: %d, "
               "expected: %d)\n",
               sm_ver, 75);
        return;
    }
    Network network(cn);
    auto outputs = make_det(network, 16, dtype::QuantizedS4{1.f});
    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({outputs}, strategy);

    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    using Attribute = LayoutTransformContext::Attribute;
    using Target = LayoutTransformContext::Target;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(), opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),        opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),  opr::WarpPerspectiveForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW,    TensorFormats::NHWC,    TensorFormats::NCHWc4,
            TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {
            OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
            ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::CHWN4, OprFormat::NHWC,
                OprFormat::NCHW64})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NCHW64,
                     OprFormat::NHWC, OprFormat::CHWN4});
#if MGB_WITH_CACHED_TEST
    auto profiler = std::make_unique<ProfilerMock>(
            static_cast<const uint8_t*>(TestLayoutTransform_Detection_QS4.data()),
            TestLayoutTransform_Detection_QS4.size());
#else
    auto profiler = ProfilerBase::make_cached_profiler(
            "TestLayoutTransform.Detection_QS4.cache");
#endif
    std::unique_ptr<SolverBase> solver{
            new DynamicProgrammingSolver(std::move(profiler))};
    auto new_outputs =
            gopt::GraphOptimizer{}
                    .add_pass<FuseConvBiasNonlinPass>()
                    .add_pass<FuseConvBiasZPass>()
                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                    .add_pass<ShuffleShuffleRemovePass>()
                    .add_pass(FuseNCHW4Int8Preprocess::make())
                    .add_pass<FoldingConvBiasDimshufflePass>()
                    .add_pass<ParamFusePass>()
                    .add_pass<ParamMergePass>()
                    .apply({{outputs}})
                    .endpoint_vars();

    GraphProfiler gprof{network.graph.get()};
    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
    std::vector<OutputSpecItem> output_spec;
    for (const auto& i : new_outputs) {
        output_spec.emplace_back(OutputSpecItem{i, {}});
    }
    auto func = network.graph->compile(output_spec);
    func->execute();
    gprof.to_json_full(func.get())->writeto_fpath(output_file("det_qs4.json"));
}
#endif

/*!
 * test the performance of the solver when the network is wide.
 */
TEST(TestLayoutTransform, Wide) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    Network network(cn);
    auto data = network.add_var("data", {16, 3, 64, 64});
    auto f = network.add_conv(data, 16, {3, 3}, dtype::Float32(), true, {2, 2}, {1, 1});
    f = network.add_conv(f, 16, {3, 3}, dtype::Float32(), true, {2, 2}, {1, 1});
    f = network.add_conv(f, 16, {3, 3}, dtype::Float32(), true, {2, 2}, {1, 1});
    SymbolVarArray stages;
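    /// keep every intermediate stage alive until the final sum, so many variables
    /// coexist and the solver has to handle a wide graph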
    for (size_t i = 0; i < 8; ++i) {
        f = f * f + f;
        stages.push_back(f);
    }
    auto y = stages[0];
    for (size_t i = 1; i < stages.size(); ++i) {
        y = y + stages[i];
    }

    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({y}, strategy);

    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    using Attribute = LayoutTransformContext::Attribute;
    using Target = LayoutTransformContext::Target;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::Elemwise::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW, TensorFormats::NHWC};
    Attribute attribute = {
            OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
            ReformatAttribute::DEFAULT};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
            opr::ConvBiasForward::typeinfo(), {OprFormat::NCHW, OprFormat::NHWC});
#if MGB_WITH_CACHED_TEST
    auto profiler = std::make_unique<ProfilerMock>(
            static_cast<const uint8_t*>(TestLayoutTransform_Wide.data()),
            TestLayoutTransform_Wide.size());
#else
    auto profiler =
            ProfilerBase::make_cached_profiler("TestLayoutTransform.Wide.cache");
#endif
    std::unique_ptr<SolverBase> solver{
            new DynamicProgrammingSolver(std::move(profiler))};
    auto v = gopt::GraphOptimizer{}
                     .add_pass<FuseConvBiasNonlinPass>()
                     .add_pass<FuseConvBiasZPass>()
                     .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                     .add_pass<ShuffleShuffleRemovePass>()
                     .add_pass<ParamFusePass>()
                     .add_pass<ParamMergePass>()
                     .apply({{y}})
                     .endpoint_vars();
    const auto& sym_o = v[0];
    GraphProfiler gprof{network.graph.get()};
    auto func = network.graph->compile({{sym_o, {}}});
    func->execute();
    gprof.to_json_full(func.get())->writeto_fpath(output_file("wide.json"));
    auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(sym_o);
    ASSERT_EQ(nr_dimshuffle, 0u);
    auto nr_param_merge = find_opr_num<opr::MultipleDeviceTensorHolder>(sym_o);
    ASSERT_EQ(nr_param_merge, 1u);
    /// check first conv format
    const auto& first_conv = find_opr<opr::ConvBiasForward>(sym_o);
    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NCHW);
}

#if CUDA_VERSION >= 10020
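//! a small detection-head-like graph mixing QuantizedS8 and Quantized4Asymm
//! convolutions; checks the number of reformats and that the first conv ends up
//! in NHWC with 4-bit output.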
TEST(TestLayoutTransform, DetectionHead) {
    REQUIRE_GPU(1);
    auto cn = CompNode::load("gpu0");
    cn.activate();
    REQUIRE_CUDA_COMPUTE_CAPABILITY_EQ(7, 5);

    constexpr size_t N = 16, C = 3, H = 736, W = 1280;
    HostTensorGenerator<dtype::Uint8> gen;

    auto graph = ComputingGraph::make();
    auto h2d = opr::Host2DeviceCopy::make(*graph, gen({N, C, H, W}, cn));
    auto data = opr::TypeCvt::make(h2d, dtype::Float32());
    auto sub_128 = data + (-128);
    auto x = opr::TypeCvt::make(sub_128, dtype::QuantizedS8(1.f));
    auto mkcvar = [&](const char* name, const TensorShape& shp, const DType& dtype) {
        return opr::TypeCvt::make(
                opr::SharedDeviceTensor::make(*graph, *gen(shp, cn)).rename(name),
                dtype);
    };
    auto w = mkcvar("w", {16, 3, 3, 3}, dtype::QuantizedS8(1.f));
    auto b = mkcvar("b", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
    opr::ConvBias::Param param;
    param.format = opr::ConvBias::Param::Format::NCHW;
    param.nonlineMode = opr::ConvBias::Param::NonlineMode::RELU;
    param.stride_h = param.stride_w = 2;
    param.pad_h = param.pad_w = 1;
    auto conv_1 = opr::ConvBias::make(
            x, w, b, param, {}, OperatorNodeConfig(dtype::QuantizedS8(1.f)));
    conv_1 = opr::TypeCvt::make(
            conv_1, dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8)));
    auto w1 = mkcvar("w1", {16, 16, 3, 3}, dtype::QuantizedS4(1.f));
    auto b1 = mkcvar("b1", {1, 16, 1, 1}, dtype::QuantizedS32(1.f));
    auto y = opr::ConvBias::make(
            conv_1, w1, b1, param, {},
            OperatorNodeConfig(dtype::Quantized4Asymm(1.f, static_cast<uint8_t>(8))));

    using S = opr::mixin::AlgoChooserHelper::ExecutionPolicy::Strategy;
    S strategy = S::PROFILE;
    gopt::modify_opr_algo_strategy_inplace({y}, strategy);

    using OprFormat = LayoutTransformContext::OprFormat;
    using OprList = LayoutTransformContext::OprList;
    using Attribute = LayoutTransformContext::Attribute;
    using ReformatAttribute = LayoutTransformContext::ReformatAttribute;
    using Target = LayoutTransformContext::Target;
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::ConvolutionBackwardData::typeinfo(),
            opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::WarpPerspectiveForward::typeinfo(),
    };
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW,    TensorFormats::NHWC,    TensorFormats::NCHWc4,
            TensorFormats::NCHWc32, TensorFormats::NCHWc64, TensorFormats::CHWNc4};
    Attribute attribute = {
            OprFormat::NCHW, TensorFormats::NCHW, Target::UNSPEC,
            ReformatAttribute::AUTO_PADDING_NHWC};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats), attribute);
    ctx->add_opr_config(
               opr::ConvBiasForward::typeinfo(),
               {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW32,
                OprFormat::NCHW64, OprFormat::CHWN4})
            .add_opr_config(
                    opr::ConvolutionForward::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NCHW4})
            .add_opr_config(
                    opr::ConvolutionBackwardData::typeinfo(),
                    {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4})
            .add_opr_config(
                    opr::PoolingForward::typeinfo(),
                    {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
                     OprFormat::NCHW64, OprFormat::CHWN4})
            .add_opr_config(
                    opr::WarpPerspectiveForward::typeinfo(),
                    {OprFormat::NHWC, OprFormat::NCHW4, OprFormat::NCHW64});
#if MGB_WITH_CACHED_TEST
    auto profiler = std::make_unique<ProfilerMock>(
            static_cast<const uint8_t*>(TestLayoutTransform_DetectionHead.data()),
            TestLayoutTransform_DetectionHead.size());
#else
    auto profiler = ProfilerBase::make_cached_profiler(
            "TestLayoutTransform.DetectionHead.cache");
#endif
    std::unique_ptr<SolverBase> solver{
            new DynamicProgrammingSolver(std::move(profiler))};
    auto new_out_vars =
            gopt::GraphOptimizer{}
                    .add_pass<LayoutTransformPass>(std::move(ctx), std::move(solver))
                    .add_pass<ShuffleShuffleRemovePass>()
                    .add_pass(FuseNCHW4Int8Preprocess::make())
                    .add_pass<FoldingConvBiasDimshufflePass>()
                    .add_pass<FoldingConvBiasTypecvtPass>()
                    .add_pass<ParamFusePass>()
                    .add_pass<ParamMergePass>()
                    .apply(SymbolVarArray{y})
                    .endpoint_vars();
    const auto& v = new_out_vars[0];
    using OutputSpecItem = cg::ComputingGraph::OutputSpecItem;
    std::vector<OutputSpecItem> outs;
    for (const auto& i : new_out_vars) {
        outs.emplace_back(OutputSpecItem{i, {}});
    }
    GraphProfiler gprof{graph.get()};
    auto func = graph->compile(outs);
    func->execute();
    gprof.to_json_full(func.get())->writeto_fpath(output_file("det_head.json"));
    /// check reformat
    auto nr_reformat = find_opr_num<opr::RelayoutFormat>(v);
    ASSERT_EQ(nr_reformat, 2u);
    /// check dimshuffle
    auto nr_dimshuffle = find_opr_num<opr::Dimshuffle>(v);
    ASSERT_EQ(nr_dimshuffle, 0u);
    /// check conv_bias
    auto nr_conv = find_opr_num<opr::ConvBiasForward>(v);
    ASSERT_EQ(nr_conv, 2u);
    /// check first conv format
    const auto& first_conv = find_opr<opr::ConvBiasForward>(v);
    const auto& cast = first_conv.cast_final_safe<opr::ConvBiasForward>();
    ASSERT_EQ(cast.param().format, opr::ConvBias::Param::Format::NHWC);
    ASSERT_EQ(cast.output()[0]->dtype().enumv(), DTypeEnum::Quantized4Asymm);
}
#endif
#endif

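//! manually emit an NCHW4 -> NHWC reformat subgraph and check that
//! ShuffleShuffleRemovePass removes the redundant shuffles without changing results.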
TEST(TestLayoutTransform, CanonicalizeLayoutTransform) {
    constexpr size_t N = 64, C = 64, H = 1, W = 1;
    auto cn = CompNode::load("xpu0");
    Network network(cn);
    auto x = network.add_var("x", {N, C / 4, H, W, 4});
    x = network.add_type_cvt(x, dtype::QuantizedS4{1.f});
    using NamedTensorShape = megdnn::NamedTensorShape;
    auto src =
            NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NCHW4);
    auto dst =
            NamedTensorShape::make_named_tensor_shape(NamedTensorShape::Format::NHWC);
    auto&& tuple = gopt::ReformatEmitter(src, dst).emit();
    auto builder = std::get<0>(tuple);
    x = SymbolVar(builder({x.node()}));
    x = opr::Reshape::make(x, {N, H, W, C});
    x = network.add_type_cvt(x, dtype::Float32());

    SymbolVar another_x;
    unpack_vector(
            gopt::GraphOptimizer{}
                    .add_pass<gopt::ShuffleShuffleRemovePass>()
                    .apply({{x}})
                    .endpoint_vars(),
            another_x);
    const auto& astype = find_opr<opr::TypeCvt>(x);
    EXPECT_TRUE(
            astype.input(0)->owner_opr()->dyn_typeinfo() ==
            opr::Host2DeviceCopy::typeinfo());
    const auto& another_astype = find_opr<opr::TypeCvt>(another_x);
    EXPECT_TRUE(
            another_astype.input(0)->owner_opr()->dyn_typeinfo() ==
            opr::Reshape::typeinfo());
    size_t nr_type_cvt = find_opr_num<opr::TypeCvt>(another_x);
    ASSERT_EQ(nr_type_cvt, 2u);

    HostTensorND t1;
    auto func1 = network.graph->compile({make_callback_copy(x, t1)});
    func1->execute();

    HostTensorND t2;
    auto func2 = network.graph->compile({make_callback_copy(another_x, t2)});
    func2->execute();
    MGB_ASSERT_TENSOR_EQ(t1, t2);
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}