/** * \file src/gopt/impl/utils.h * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") * * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. * * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #pragma once #include "megbrain/gopt/global_layout_transform.h" namespace mgb { namespace gopt { static inline const char* opr_format_to_string( OprTensorFormatsConfiguration::OprFormat opr_format) { using OprFormat = OprTensorFormatsConfiguration::OprFormat; #define cb(_fmt) \ case OprFormat::_fmt: \ return #_fmt switch (opr_format) { cb(NCHW); cb(NHWC); cb(NCHW4); cb(NCHW32); cb(NCHW64); cb(CHWN4); default: mgb_assert(false, "Invalid opr format(got:%u)", static_cast(opr_format)); } #undef cb } static inline TensorFormats opr_format_to_tensor_formats( OprTensorFormatsConfiguration::OprFormat opr_format) { using OprFormat = OprTensorFormatsConfiguration::OprFormat; switch (opr_format) { case OprFormat::NCHW: return TensorFormats::NCHW; case OprFormat::NHWC: return TensorFormats::NHWC; case OprFormat::NCHW4: return TensorFormats::NCHWc4; case OprFormat::NCHW32: return TensorFormats::NCHWc32; case OprFormat::NCHW64: return TensorFormats::NCHWc64; case OprFormat::CHWN4: return TensorFormats::CHWNc4; default: mgb_throw(AssertionError, "format(%s) is not supported", opr_format_to_string(opr_format)); }; } static inline megdnn::NamedTensorShape tensor_formats_to_named_tensor_shape( TensorFormats format) { switch (format) { case TensorFormats::NCHW: return {{"N"}, {"C"}, {"H"}, {"W"}}; case TensorFormats::NHWC: return {{"N"}, {"H"}, {"W"}, {"C"}}; case TensorFormats::NCHWc4: return {{"N"}, {"C//4"}, {"H"}, {"W"}, {"C%4"}}; case TensorFormats::NCHWc8: return {{"N"}, {"C//8"}, {"H"}, {"W"}, {"C%8"}}; case TensorFormats::NCHWc32: return {{"N"}, {"C//32"}, {"H"}, {"W"}, {"C%32"}}; case TensorFormats::NCHWc64: return {{"N"}, {"C//64"}, {"H"}, {"W"}, {"C%64"}}; case TensorFormats::CHWNc4: return {{"C//4"}, {"H"}, {"W"}, {"N"}, {"C%4"}}; case TensorFormats::NHCWc4: return {{"N"}, {"H"}, {"C//4"}, {"W"}, {"C%4"}}; case TensorFormats::KRSCk4: return {{"K//4"}, {"R"}, {"S"}, {"C"}, {"K%4"}}; case TensorFormats::GKRSCk4: return {{"G"}, {"K//4"}, {"R"}, {"S"}, {"C"}, {"K%4"}}; case TensorFormats::C1RSc4: return {{"C//4"}, {"C%1"}, {"R"}, {"S"}, {"C%4"}}; case TensorFormats::KRSCk4c4: return {{"K//4"}, {"R"}, {"S"}, {"C//4"}, {"K%4"}, {"C%4"}}; case TensorFormats::GKRSCk4c4: return {{"G"}, {"K//4"}, {"R"}, {"S"}, {"C//4"}, {"K%4"}, {"C%4"}}; case TensorFormats::KCRSk4c4: return {{"K//4"}, {"C//4"}, {"R"}, {"S"}, {"K%4"}, {"C%4"}}; case TensorFormats::GKCRSk4c4: return {{"G"}, {"K//4"}, {"C//4"}, {"R"}, {"S"}, {"K%4"}, {"C%4"}}; case TensorFormats::KCRSc4k4: return {{"K//4"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}, {"K%4"}}; case TensorFormats::GKCRSc4k4: return {{"G"}, {"K//4"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}, {"K%4"}}; case TensorFormats::C11RSc4: return {{"C//4"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%4"}}; case TensorFormats::KCRSc8k8: return {{"K//8"}, {"C//8"}, {"R"}, {"S"}, {"C%8"}, {"K%8"}}; case TensorFormats::GKCRSc8k8: return {{"G"}, {"K//8"}, {"C//8"}, {"R"}, {"S"}, {"C%8"}, {"K%8"}}; case TensorFormats::C11RSc8: return {{"C//8"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}, {"C%8"}}; case TensorFormats::KRSCk8: return {{"K//8"}, {"R"}, {"S"}, {"C"}, {"K%8"}}; case TensorFormats::KCRSc4: return {{"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; case TensorFormats::GKCRSc4: return {{"G"}, {"K"}, {"C//4"}, {"R"}, {"S"}, {"C%4"}}; case TensorFormats::KCRS: return {{"K"}, {"C"}, {"R"}, {"S"}}; case TensorFormats::GKCRS: return {{"G"}, {"K"}, {"C"}, {"R"}, {"S"}}; case TensorFormats::C11RS: return {{"C"}, {"C%1"}, {"C%1"}, {"R"}, {"S"}}; default: mgb_throw(AssertionError, "invalid tensor formats(%u)", static_cast(format)); } } } // namespace gopt } // namespace mgb // vim: syntax=cpp.doxygen