/**
 * \file src/gopt/impl/layout_transform_context.cpp
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 * implied.
 */

#include "./utils.h"
#include "megbrain/gopt/global_layout_transform.h"
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/opr/imgproc.h"
#include "megbrain/opr/nn_int.h"

using namespace mgb;
using namespace gopt;

namespace {
using OprFormat = LayoutTransformContext::OprFormat;
using OprList = LayoutTransformContext::OprList;
using Attribute = LayoutTransformContext::Attribute;
using Target = LayoutTransformContext::Target;

//! Map a Target enum value to its printable name; aborts on unknown values.
const char* target_to_string(Target target) {
    switch (target) {
        case Target::CUDA:
            return "CUDA";
        case Target::X86:
            return "X86";
        case Target::ARM:
            return "ARM";
        case Target::UNSPEC:
            return "UNSPEC";
        default:
            mgb_assert(false, "unsupported target (got:%u)",
                       static_cast<uint32_t>(target));
    }
}

//! Build the layout-transform context for the CUDA backend: the operator
//! types it may rewrite, the tensor layouts it may convert between, and the
//! per-operator set of supported opr formats.
std::unique_ptr<LayoutTransformContext> make_cuda_ctx(
        OprFormat base_opr_format, TensorFormats base_tensor_format) {
    // Operator types eligible for layout transformation on CUDA.
    OprList opr_list = {
            opr::ConvBiasForward::typeinfo(),
            opr::ConvolutionForward::typeinfo(),
            opr::ConvolutionBackwardData::typeinfo(),
            opr::ElemwiseMultiType::typeinfo(),
            opr::Elemwise::typeinfo(),
            opr::TypeCvt::typeinfo(),
            opr::PoolingForward::typeinfo(),
            opr::WarpPerspectiveForward::typeinfo(),
    };

    // Tensor layouts the CUDA backend can convert between.
    SmallVector<TensorFormats> available_tensor_formats = {
            TensorFormats::NCHW,    TensorFormats::NHWC,
            TensorFormats::NCHWc4,  TensorFormats::NCHWc32,
            TensorFormats::NCHWc64, TensorFormats::CHWNc4};

    Attribute attribute = {base_opr_format, base_tensor_format, Target::CUDA};
    auto ctx = std::make_unique<LayoutTransformContext>(
            std::move(opr_list), std::move(available_tensor_formats),
            attribute);

    // Register, per operator type, the opr formats it supports on CUDA.
    ctx->add_opr_config(opr::ConvBiasForward::typeinfo(),
                        {OprFormat::NCHW, OprFormat::NHWC, OprFormat::NCHW4,
                         OprFormat::NCHW32, OprFormat::NCHW64,
                         OprFormat::CHWN4});
    ctx->add_opr_config(opr::ConvolutionForward::typeinfo(),
                        {OprFormat::NCHW, OprFormat::NCHW4});
    ctx->add_opr_config(opr::ConvolutionBackwardData::typeinfo(),
                        {OprFormat::NCHW, OprFormat::NCHW4});
    ctx->add_opr_config(opr::PoolingForward::typeinfo(),
                        {OprFormat::NCHW4, OprFormat::NCHW32, OprFormat::NHWC,
                         OprFormat::NCHW64, OprFormat::CHWN4});
    ctx->add_opr_config(opr::WarpPerspectiveForward::typeinfo(),
                        {OprFormat::NHWC, OprFormat::NCHW4,
                         OprFormat::NCHW64});
    return ctx;
}
}  // namespace

/* ================= LayoutTransformContext ==================*/
/*!
 * Register a single supported opr format for the given operator type and
 * install the matching dispatcher. Returns *this so calls can be chained.
 */
LayoutTransformContext& LayoutTransformContext::add_opr_config(
        Typeinfo* opr, OprFormat opr_format) {
    m_opr_configs[opr][opr_format] =
            OprTensorFormatsConfiguration::find_dispatcher_by_type_format(
                    opr, opr_format);
    return *this;
}

/*!
 * Register a batch of supported opr formats for the given operator type,
 * installing one dispatcher per format. Returns *this for chaining.
 */
LayoutTransformContext& LayoutTransformContext::add_opr_config(
        Typeinfo* opr, SmallVector<OprFormat> opr_formats) {
    auto& format_dispatchers = m_opr_configs[opr];
    for (const auto& fmt : opr_formats) {
        format_dispatchers[fmt] = OprTensorFormatsConfiguration::
                find_dispatcher_by_type_format(opr, fmt);
    }
    return *this;
}

/*!
 * Factory for a LayoutTransformContext specialized for \p target.
 * Only Target::CUDA is implemented; any other target trips an assertion.
 */
std::unique_ptr<LayoutTransformContext> LayoutTransformContext::make(
        Target target, OprFormat base_opr_format,
        TensorFormats base_tensor_format) {
    if (target == Target::CUDA) {
        return make_cuda_ctx(base_opr_format, base_tensor_format);
    }
    mgb_assert(false, "unsupported target %s\n", target_to_string(target));
}

// vim: syntax=cpp.doxygen