/***************************************************************************************************
 * Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *       this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *       contributors may be used to endorse or promote products derived from
 *       this software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 * INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/**
 * \file dnn/src/cuda/cutlass/operation_table.cu
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
* * Unless required by applicable law or agreed to in writing, * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or * implied. */ #include "src/common/utils.h" #include "src/cuda/cutlass/operation_table.h" ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace library { ///////////////////////////////////////////////////////////////////////////////////////////////// GemmKey get_gemm_key_from_desc(const GemmDescription& desc) { GemmKey key; key.element_A = desc.A.element; key.layout_A = desc.A.layout; key.element_B = desc.B.element; key.layout_B = desc.B.layout; key.element_C = desc.C.element; key.layout_C = desc.C.layout; key.element_accumulator = desc.tile_description.math_instruction.element_accumulator; key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); key.warp_shape_m = desc.tile_description.threadblock_shape.m() / desc.tile_description.warp_count.m(); key.warp_shape_n = desc.tile_description.threadblock_shape.n() / desc.tile_description.warp_count.n(); key.warp_shape_k = desc.tile_description.threadblock_shape.k() / desc.tile_description.warp_count.k(); key.instruction_shape_m = desc.tile_description.math_instruction.instruction_shape.m(); key.instruction_shape_n = desc.tile_description.math_instruction.instruction_shape.n(); key.instruction_shape_k = desc.tile_description.math_instruction.instruction_shape.k(); key.stages = desc.stages; key.alignment_A = desc.A.alignment; key.alignment_B = desc.B.alignment; key.split_k_mode = desc.split_k_mode; return key; } ///////////////////////////////////////////////////////////////////////////////////////////////// ConvolutionKey get_convolution_key_from_desc( const ConvolutionDescription& desc) { 
ConvolutionKey key; key.conv_op = desc.conv_op; key.element_src = desc.src.element; key.layout_src = desc.src.layout; key.element_filter = desc.filter.element; key.layout_filter = desc.filter.layout; key.element_dst = desc.dst.element; key.layout_dst = desc.dst.layout; key.element_bias = desc.bias.element; key.layout_bias = desc.bias.layout; key.convolution_type = desc.convolution_type; key.threadblock_shape_m = desc.tile_description.threadblock_shape.m(); key.threadblock_shape_n = desc.tile_description.threadblock_shape.n(); key.threadblock_shape_k = desc.tile_description.threadblock_shape.k(); key.warp_shape_m = desc.tile_description.threadblock_shape.m() / desc.tile_description.warp_count.m(); key.warp_shape_n = desc.tile_description.threadblock_shape.n() / desc.tile_description.warp_count.n(); key.warp_shape_k = desc.tile_description.threadblock_shape.k() / desc.tile_description.warp_count.k(); key.instruction_shape_m = desc.tile_description.math_instruction.instruction_shape.m(); key.instruction_shape_n = desc.tile_description.math_instruction.instruction_shape.n(); key.instruction_shape_k = desc.tile_description.math_instruction.instruction_shape.k(); key.epilogue_type = desc.epilogue_type; key.stages = desc.tile_description.threadblock_stages; key.special_optimization = desc.special_optimization; key.without_shared_load = desc.without_shared_load; return key; } ///////////////////////////////////////////////////////////////////////////////////////////////// void OperationTable::append(Manifest const& manifest) { // Insert operations into appropriate data structure for (auto const& operation : manifest) { OperationDescription const& desc = operation->description(); // insert all gemm operations into operation table if (desc.kind == OperationKind::kGemm) { GemmKey key = get_gemm_key_from_desc( static_cast(desc)); gemm_operations[key].push_back(operation.get()); } // insert all conv operations into operation table if (desc.kind == OperationKind::kConvolution) { 
ConvolutionKey key = get_convolution_key_from_desc( static_cast(desc)); convolution_operations[key].push_back(operation.get()); } } } ///////////////////////////////////////////////////////////////////////////////////////////////// Operation const* OperationTable::find_op(GemmKey const& key) const { if (gemm_operations.count(key)) { auto const& ops = gemm_operations.at(key); megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", ops.size()); return ops[0]; } return nullptr; } ///////////////////////////////////////////////////////////////////////////////////////////////// Operation const* OperationTable::find_op(ConvolutionKey const& key) const { if (convolution_operations.count(key) > 0) { auto const& ops = convolution_operations.at(key); megdnn_assert(ops.size() == 1, "exactly one kernel expected, got %zu", ops.size()); return ops[0]; } return nullptr; } ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace library } // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////