From 4e9e23cbcbf1c5c96c76561e45a8b3fd66bc7b6b Mon Sep 17 00:00:00 2001 From: zhoutianzi666 <39978853+zhoutianzi666@users.noreply.github.com> Date: Mon, 13 Mar 2023 10:45:47 +0800 Subject: [PATCH] [Paddle Inference ]use python to generate cutlass code (#50603) * use python to generate cutlass code * refine CommonConvKernelPart1, CommonConvKernelPart2 * remove useless code in generate_cutlass_code.sh * add more config in conv2d_residual * CommonCutlassConvKernelPart1 and CommonCutlassConvKernelPart2 * add group conv support in util.cu * remove .sh * refine name * make name goodgit status! * add fuse_alpha * make code easy to understand * mot fopen generate in py * use python script to generate conv2d,group=1 cutlass code * use const & * use const & && use python script to generate conv2d/group=1 code --- .gitignore | 1 + .../ir/conv2d_fusion_layout_transfer_pass.cc | 4 + paddle/phi/kernels/CMakeLists.txt | 10 +- .../fusion/cutlass/conv2d/conv2d_bias.cu | 225 ---------------- .../fusion/cutlass/conv2d/conv2d_bias_act.py | 222 ++++++++++++++++ .../cutlass/conv2d/conv2d_bias_add_relu.cu | 248 ------------------ .../cutlass/conv2d/conv2d_bias_leaky_relu.cu | 226 ---------------- .../fusion/cutlass/conv2d/conv2d_bias_relu.cu | 225 ---------------- .../cutlass/conv2d/conv2d_bias_residual.py | 214 +++++++++++++++ .../fusion/cutlass/conv2d/conv2d_bias_silu.cu | 226 ---------------- .../fusion/cutlass/conv2d/conv2d_common.py | 204 ++++++++++++++ .../fusion/cutlass/conv2d/conv2d_decl.h | 12 +- .../fusion/cutlass/conv2d/conv2d_util.cu | 26 +- .../fusion/cutlass/conv2d/conv2d_util.h | 8 +- .../kernels/fusion/cutlass/conv2d_fusion.cu | 5 +- paddle/phi/kernels/fusion/cutlass/util.py | 37 +++ .../test_cutlass_conv2d_fusion_op.py | 6 +- 17 files changed, 727 insertions(+), 1172 deletions(-) delete mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu create mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py delete mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu delete mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu delete mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu create mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py delete mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu create mode 100644 paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py create mode 100644 paddle/phi/kernels/fusion/cutlass/util.py diff --git a/.gitignore b/.gitignore index 18f857f81b..f48df37881 100644 --- a/.gitignore +++ b/.gitignore @@ -96,4 +96,5 @@ paddle/fluid/prim/api/generated/prim_api/* paddle/fluid/framework/__init__.py paddle/phi/api/profiler/__init__.py python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py +paddle/phi/kernels/fusion/cutlass/conv2d/generated/* python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc index 31041754b0..9a82244b44 100644 --- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc +++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc @@ -207,6 +207,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { op_desc->SetAttr("data_format", std::string{"NHWC"}); if (cutlass_enable && CutlassIsValid(op_node)) { op_desc->SetType("conv2d_fusion_cutlass"); + // conv2d_fusion_cutlass must have this attribute because of signature. + if (!op_desc->HasAttr("fuse_alpha")) { + op_desc->SetAttr("fuse_alpha", 0.f); + } } op_desc->Flush(); diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 5eed394a5f..b873e531b6 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -114,7 +114,15 @@ file( "fusion/gpu/*.cu") if(WITH_CUTLASS) - file(GLOB cutlass_cu "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu") + execute_process( + COMMAND ${CMAKE_COMMAND} -E make_directory + "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d/generated" + COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_act.py" + COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_residual.py" + WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d") + + file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu" + "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu") list(APPEND kernel_cu ${cutlass_cu}) endif() diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu deleted file mode 100644 index 234f2437cd..0000000000 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" - -namespace phi { -namespace fusion { -namespace cutlass_internal { - -template -cutlass::Status Conv2dBiasImpl(ConvAllParams params) { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = cutlass::half_t; - using ElementInputB = cutlass::half_t; - using ElementOutput = cutlass::half_t; - using LayoutInputA = cutlass::layout::TensorNHWC; - using LayoutInputB = cutlass::layout::TensorNHWC; - using LayoutOutput = cutlass::layout::TensorNHWC; - using MMAOp = cutlass::arch::OpClassTensorOp; - using SmArch = cutlass::arch::Sm75; - using ThreadblockShape = TShape; - using WarpShape = WShape; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; - constexpr int NumStages = 2; - static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = - cutlass::conv::IteratorAlgorithm::kOptimized; - using EpilogueOp = - cutlass::epilogue::thread::LinearCombination; - - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementInputA, - LayoutInputA, - ElementInputB, - LayoutInputB, - ElementOutput, - LayoutOutput, - ElementAccumulator, - MMAOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOp, - SwizzleThreadBlock, - NumStages, - cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm, - cutlass::conv::StrideSupport::kStrided, - Alignment, - Alignment>::Kernel; - using ImplicitGemm = - cutlass::conv::device::ImplicitGemmConvolution; - - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - half *output = params.output; - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - int oh = params.oh; - int ow = params.ow; - int dilation_h = params.dilation_h; - int dilation_w = params.dilation_w; - - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; - cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, - {oc, kh, kw, ic}, - {pad_h0, 0, pad_w0, 0}, - {stride_h, stride_w}, - {dilation_h, dilation_w}, - {batch, oh, ow, oc}, - mode, - 1); - - typename ImplicitGemm::Arguments arguments{ - problem_size, - {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, - {(cutlass::half_t *)(bias), {0, 0, 0}}, - {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, - {1.f, 1.f}}; - - ImplicitGemm implicit_gemm_op; - size_t bytes = implicit_gemm_op.get_workspace_size(arguments); - - auto ctx = params.ctx; - auto stream = ctx->stream(); - phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc( - ctx->GetPlace(), - bytes, - phi::Stream(reinterpret_cast(stream))); - void *workspace = tmp_gpu_ptrs_data->ptr(); - - cutlass::Status status = implicit_gemm_op.can_implement(arguments); - CUTLASS_CHECK(status); - status = implicit_gemm_op.initialize(arguments, workspace); - CUTLASS_CHECK(status); - status = implicit_gemm_op(stream); - CUTLASS_CHECK(status); - return status; -} - -// config 0 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>( - ConvAllParams); -// config 1 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>( - ConvAllParams); -// config 2 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>( - ConvAllParams); -// config 3 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>( - ConvAllParams); -// config 4 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 32>>( - ConvAllParams); -// config 5 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 64, 32>>( - ConvAllParams); -// config 6 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<64, 64, 32>>( - ConvAllParams); -// config 7 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<64, 64, 32>>( - ConvAllParams); -// config 8 -template cutlass::Status Conv2dBiasImpl, - cutlass::gemm::GemmShape<64, 32, 32>>( - ConvAllParams); - -std::vector> - conv2d_bias_all_func = { - Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 32, 32>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<32, 64, 32>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasImpl, - cutlass::gemm::GemmShape<64, 32, 32>>}; - -std::map, int> map_problem_conv2d_bias; -std::mutex conv2d_bias_mutex; - -void Conv2dBias(ConvAllParams params) { - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - std::vector problem_size = { - batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; - - if (map_problem_conv2d_bias.count(problem_size)) { - conv2d_bias_all_func[map_problem_conv2d_bias.at(problem_size)](params); - return; - } - - int best_config_index = - ProfileToGetBestConfig(conv2d_bias_all_func, params, CONV2D_BIAS); - - std::lock_guard guard(conv2d_bias_mutex); - map_problem_conv2d_bias[problem_size] = best_config_index; - conv2d_bias_all_func[best_config_index](params); -} -} // namespace cutlass_internal -} // namespace fusion -} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py new file mode 100644 index 0000000000..86e863cb4c --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_act.py @@ -0,0 +1,222 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +sys.path.append("../") +import enum + +from conv2d_common import ( + CommonConvFunction, + CommonCutlassConvKernelDeclare, + CommonCutlassConvKernelExecute, + CommonTail, + GenerateFunctionForPhi, +) +from util import SubstituteTemplate, TileDesc + +# this is a file's header part + +cba_header = ''' +// Generated by conv2d_bias_act.py - Do not edit. + +#include +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { +''' + +# This is a cutlass kernel, will be many these like kernels + +dict_for_declare_part = { + "conv_kind_name": "Fprop", + "epi_part": "${epi_func}< ${element_c}, ${epilogue_vector_length}, ${element_accum}, ${element_epilogue}>", +} + +cba_kernel_no_alpha = ( + SubstituteTemplate(CommonCutlassConvKernelDeclare, dict_for_declare_part) + + ''' + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}}, + {(cutlass::half_t *)(bias), {0, 0, 0}}, + {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f}}; +''' + + CommonCutlassConvKernelExecute +) + +# This is used for leaky_relu, this activation need a fuse_alpha parameter. + +cba_kernel_alpha = cba_kernel_no_alpha.replace( + "{1.f, 1.f}", "{1.f, 1.f, alpha}" +).replace( + "typename ImplicitG", "float alpha = params.alpha; typename ImplicitG" +) + + +class CbaAct(enum.Enum): + Identity = 1 + Relu = 2 + Silu = 3 + LeakyRelu = 4 + + +# Some global variables used, now we only support these activations. +SupportedAct = [ + CbaAct.Identity, + CbaAct.Relu, + CbaAct.Silu, + CbaAct.LeakyRelu, +] + +ActTag = { + SupportedAct[0]: 'cutlass::epilogue::thread::LinearCombination', + SupportedAct[1]: 'cutlass::epilogue::thread::LinearCombinationRelu', + SupportedAct[2]: 'cutlass::epilogue::thread::LinearCombinationSilu', + SupportedAct[3]: 'cutlass::epilogue::thread::LinearCombinationLeakyRelu', +} + +UnderScoreName = { + SupportedAct[0]: "conv2d_bias", + SupportedAct[1]: "conv2d_bias_relu", + SupportedAct[2]: "conv2d_bias_silu", + SupportedAct[3]: "conv2d_bias_leaky_relu", +} + +CamelName = { + SupportedAct[0]: "Conv2dBias", + SupportedAct[1]: "Conv2dBiasRelu", + SupportedAct[2]: "Conv2dBiasSilu", + SupportedAct[3]: "Conv2dBiasLeakyRelu", +} + +# Generate sm75 TensorOp conv code. +# CUTLASS Tensor Core operations are implemented using CUDA's mma instruction. +# Here is mma.m16n8k8. + + +def generate_sm75_1688(): + kernel_dict = { + "element_a": "cutlass::half_t", + "layout_a": "cutlass::layout::TensorNHWC", + "element_b": "cutlass::half_t", + "layout_b": "cutlass::layout::TensorNHWC", + "element_c": "cutlass::half_t", + "layout_c": "cutlass::layout::TensorNHWC", + "opcode_class": "cutlass::arch::OpClassTensorOp", + "arch": "cutlass::arch::Sm75", + "stages": "2", + "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + # alpha is always float! + "element_epilogue": "float", + "math_operator": "cutlass::arch::OpMultiplyAdd", + } + + kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided" + + # iterate over this loop + element_accums = ["cutlass::half_t", "float"] + iterator_algorithms = [ + "cutlass::conv::IteratorAlgorithm::kOptimized", + # "cutlass::conv::IteratorAlgorithm::kAnalytic", + ] + + math_instructions = [ + ( + "16,8,8", + "cutlass::half_t", + "cutlass::half_t", + "cutlass::half_t", + ), + ( + "16,8,8", + "cutlass::half_t", + "cutlass::half_t", + "float", + ), + ] + + alignments = [8] + + kernel_dict["align_a"] = "8" + kernel_dict["align_b"] = "8" + # this should divided by oc + kernel_dict["epilogue_vector_length"] = "8" + + sm75_code = "" + for epi_func in SupportedAct: + op_dict = {} + op_dict["func_name"] = UnderScoreName[epi_func].lower() + "_sm75" + op_dict["enum_op_name"] = UnderScoreName[epi_func].upper() + # For a function, we record all its kernels into a std::vector in C++ code + all_kernel_names = "" + kernel_dict["epi_func"] = ActTag[epi_func] + suffix = 0 + for iterator_algorithm in iterator_algorithms: + for alignment in alignments: + for math_inst in math_instructions: + tiles = [ + TileDesc("64, 64, 64", 2, "32, 32, 64", math_inst), + TileDesc("64, 32, 64", 2, "32, 32, 64", math_inst), + TileDesc("128, 32, 64", 2, "32, 32, 64", math_inst), + TileDesc("128, 64, 64", 2, "32, 32, 64", math_inst), + TileDesc("64, 64, 32", 2, "32, 32, 32", math_inst), + TileDesc("64, 128, 32", 2, "32, 64, 32", math_inst), + TileDesc("64, 128, 64", 2, "64, 64, 32", math_inst), + TileDesc("64, 256, 32", 2, "64, 64, 32", math_inst), + TileDesc("128, 64, 32", 2, "64, 32, 32", math_inst), + ] + for tile in tiles: + kernel_dict["iterator_algorithm"] = iterator_algorithm + kernel_dict["Tshape"] = tile.Tshape + kernel_dict["Wshape"] = tile.Wshape + kernel_dict["Ishape"] = tile.math_inst[0] + kernel_dict["element_accum"] = tile.math_inst[3] + kernel_dict["kernel_func_name"] = op_dict[ + "func_name" + ] + str(suffix) + suffix += 1 + cba_kernel = cba_kernel_no_alpha + if epi_func in [CbaAct.LeakyRelu]: + cba_kernel = cba_kernel_alpha + sm75_code += SubstituteTemplate(cba_kernel, kernel_dict) + all_kernel_names += ( + kernel_dict["kernel_func_name"] + ", \n" + ) + + # Generate op code + op_dict["all_kernel_func_name"] = all_kernel_names + sm75_code += SubstituteTemplate(CommonConvFunction, op_dict) + return sm75_code + + +if __name__ == "__main__": + sm_versions = ["75"] + all_code = cba_header + all_code += generate_sm75_1688() + all_code += GenerateFunctionForPhi( + sm_versions, SupportedAct, UnderScoreName, CamelName + ) + all_code += CommonTail + with open("generated/conv2d_bias_act.cu", "w") as f: + f.write(all_code) + f.close() diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu deleted file mode 100644 index d69c0bda61..0000000000 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" -#include "cutlass/epilogue/thread/linear_combination_residual_block.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" - -namespace phi { -namespace fusion { - -namespace cutlass_internal { - -template -cutlass::Status Conv2dBiasAddReluImpl(ConvAllParams params) { - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< - cutlass::half_t, - float, - float, - cutlass::half_t, - Alignment, - cutlass::epilogue::thread::Identity, - cutlass::plus, - cutlass::epilogue::thread::ReLu>; - - using Conv2dFpropKernel = - typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< - cutlass::half_t, - cutlass::layout::TensorNHWC, - cutlass::half_t, - cutlass::layout::TensorNHWC, - cutlass::half_t, - cutlass::layout::TensorNHWC, - float, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm75, - TShape, - WShape, - cutlass::gemm::GemmShape<16, 8, 8>, - EpilogueOp, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, - 2, - cutlass::arch::OpMultiplyAdd, - cutlass::conv::IteratorAlgorithm::kOptimized, - cutlass::conv::StrideSupport::kStrided, - Alignment, - Alignment>::Kernel; - - using ImplicitGemm = - cutlass::conv::device::ImplicitGemmConvolution; - - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - - half *output = params.output; - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - const half *residual = params.residual; - - int oh = params.oh; - int ow = params.ow; - int dilation_h = params.dilation_h; - int dilation_w = params.dilation_w; - - cutlass::conv::Conv2dProblemSize problem_size( - {batch, ih, iw, ic}, - {oc, kh, kw, ic}, - {pad_h0, 0, pad_w0, 0}, - {stride_h, stride_w}, - {dilation_h, dilation_w}, - {batch, oh, ow, oc}, - cutlass::conv::Mode::kCrossCorrelation, - 1); - - typename ImplicitGemm::Arguments arguments{ - problem_size, - {(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)weight, {ic, ic * kw, ic * kw * kh}}, - {(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}}, - {(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}}, - {1.f, 1.f}, - cutlass::conv::SplitKMode::kSerial, - (cutlass::half_t *)(bias), - nullptr, - 0, - oc}; - - ImplicitGemm implicit_gemm_op; - size_t bytes = implicit_gemm_op.get_workspace_size(arguments); - - auto ctx = params.ctx; - auto stream = ctx->stream(); - phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc( - ctx->GetPlace(), - bytes, - phi::Stream(reinterpret_cast(stream))); - void *workspace = tmp_gpu_ptrs_data->ptr(); - - cutlass::Status status = implicit_gemm_op.can_implement(arguments); - CUTLASS_CHECK(status); - status = implicit_gemm_op.initialize(arguments, workspace); - CUTLASS_CHECK(status); - status = implicit_gemm_op(stream); - CUTLASS_CHECK(status); - return status; -} - -// config 0 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 1 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 2 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 3 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 4 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); -// config 5 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); -// config 6 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 7 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 8 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); -// config 9 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 10 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 11 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 12 -template cutlass::Status - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); - -std::vector> - conv2d_bias_add_relu_all_func = { - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasAddReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>}; -std::map, int> map_problem_conv2d_bias_add_relu; -std::mutex conv2d_bias_add_relu_mutex; - -void Conv2dBiasAddRelu(ConvAllParams params) { - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - std::vector problem_size = { - batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; - - if (map_problem_conv2d_bias_add_relu.count(problem_size)) { - conv2d_bias_add_relu_all_func[map_problem_conv2d_bias_add_relu.at( - problem_size)](params); - return; - } - - std::lock_guard guard(conv2d_bias_add_relu_mutex); - - // config 6's diff is large. - conv2d_bias_add_relu_all_func[6] = nullptr; - - int best_config_index = ProfileToGetBestConfig( - conv2d_bias_add_relu_all_func, params, CONV2D_BIAS_ADD_RELU); - map_problem_conv2d_bias_add_relu[problem_size] = best_config_index; - conv2d_bias_add_relu_all_func[best_config_index](params); -} -} // namespace cutlass_internal -} // namespace fusion -} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu deleted file mode 100644 index 10bd7a2352..0000000000 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" - -namespace phi { -namespace fusion { -namespace cutlass_internal { -template -cutlass::Status Conv2dBiasLeakyReluImpl(ConvAllParams params) { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = cutlass::half_t; - using ElementInputB = cutlass::half_t; - using ElementOutput = cutlass::half_t; - using LayoutInputA = cutlass::layout::TensorNHWC; - using LayoutInputB = cutlass::layout::TensorNHWC; - using LayoutOutput = cutlass::layout::TensorNHWC; - using MMAOp = cutlass::arch::OpClassTensorOp; - using SmArch = cutlass::arch::Sm75; - using ThreadblockShape = TShape; - using WarpShape = WShape; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; - constexpr int NumStages = 2; - static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = - cutlass::conv::IteratorAlgorithm::kOptimized; - using EpilogueOp = cutlass::epilogue::thread::LinearCombinationLeakyRelu< - ElementOutput, - Alignment, - float, - ElementComputeEpilogue>; - - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementInputA, - LayoutInputA, - ElementInputB, - LayoutInputB, - ElementOutput, - LayoutOutput, - ElementAccumulator, - MMAOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOp, - SwizzleThreadBlock, - NumStages, - cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm, - cutlass::conv::StrideSupport::kStrided, - Alignment, - Alignment>::Kernel; - using ImplicitGemm = - cutlass::conv::device::ImplicitGemmConvolution; - - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - half *output = params.output; - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - float alpha = params.alpha; - - int oh = params.oh; - int ow = params.ow; - int dilation_h = params.dilation_h; - int dilation_w = params.dilation_w; - - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; - cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, - {oc, kh, kw, ic}, - {pad_h0, 0, pad_w0, 0}, - {stride_h, stride_w}, - {dilation_h, dilation_w}, - {batch, oh, ow, oc}, - mode, - 1); - - typename ImplicitGemm::Arguments arguments{ - problem_size, - {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, - {(cutlass::half_t *)(bias), {0, 0, 0}}, - {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, - {1.f, 1.f, alpha}}; - - ImplicitGemm implicit_gemm_op; - size_t bytes = implicit_gemm_op.get_workspace_size(arguments); - - auto ctx = params.ctx; - auto stream = ctx->stream(); - phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc( - ctx->GetPlace(), - bytes, - phi::Stream(reinterpret_cast(stream))); - void *workspace = tmp_gpu_ptrs_data->ptr(); - - cutlass::Status status = implicit_gemm_op.can_implement(arguments); - CUTLASS_CHECK(status); - status = implicit_gemm_op.initialize(arguments, workspace); - CUTLASS_CHECK(status); - status = implicit_gemm_op(stream); - CUTLASS_CHECK(status); - return status; -} - -// config 0 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<64, 64, 64>, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 1 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<64, 32, 64>, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 2 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<128, 32, 64>, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 3 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<128, 64, 64>, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 4 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); -// config 5 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<64, 128, 32>, - cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); -// config 6 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<64, 128, 64>, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 7 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<64, 256, 32>, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 8 -template cutlass::Status Conv2dBiasLeakyReluImpl< - cutlass::gemm::GemmShape<128, 64, 32>, - cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); - -std::vector> - conv2d_bias_leaky_relu_all_func = { - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasLeakyReluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>}; - -std::map, int> map_problem_conv2d_bias_leaky_relu; -std::mutex conv2d_bias_leaky_relu_mutex; - -void Conv2dBiasLeakyRelu(ConvAllParams params) { - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - std::vector problem_size = { - batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; - - if (map_problem_conv2d_bias_leaky_relu.count(problem_size)) { - conv2d_bias_leaky_relu_all_func[map_problem_conv2d_bias_leaky_relu.at( - problem_size)](params); - return; - } - - int best_config_index = ProfileToGetBestConfig( - conv2d_bias_leaky_relu_all_func, params, CONV2D_BIAS_LEAKY_RELU); - - std::lock_guard guard(conv2d_bias_leaky_relu_mutex); - map_problem_conv2d_bias_leaky_relu[problem_size] = best_config_index; - conv2d_bias_leaky_relu_all_func[best_config_index](params); -} -} // namespace cutlass_internal -} // namespace fusion -} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu deleted file mode 100644 index 22863e026b..0000000000 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu +++ /dev/null @@ -1,225 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" - -namespace phi { -namespace fusion { -namespace cutlass_internal { -template -cutlass::Status Conv2dBiasReluImpl(ConvAllParams params) { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = cutlass::half_t; - using ElementInputB = cutlass::half_t; - using ElementOutput = cutlass::half_t; - using LayoutInputA = cutlass::layout::TensorNHWC; - using LayoutInputB = cutlass::layout::TensorNHWC; - using LayoutOutput = cutlass::layout::TensorNHWC; - using MMAOp = cutlass::arch::OpClassTensorOp; - using SmArch = cutlass::arch::Sm75; - using ThreadblockShape = TShape; - using WarpShape = WShape; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; - constexpr int NumStages = 2; - static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = - cutlass::conv::IteratorAlgorithm::kOptimized; - using EpilogueOp = - cutlass::epilogue::thread::LinearCombinationRelu; - - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementInputA, - LayoutInputA, - ElementInputB, - LayoutInputB, - ElementOutput, - LayoutOutput, - ElementAccumulator, - MMAOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOp, - SwizzleThreadBlock, - NumStages, - cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm, - cutlass::conv::StrideSupport::kStrided, - Alignment, - Alignment>::Kernel; - using ImplicitGemm = - cutlass::conv::device::ImplicitGemmConvolution; - - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - half *output = params.output; - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - - int stride_h = params.stride_h; - int stride_w = params.stride_w; - int oh = params.oh; - int ow = params.ow; - int dilation_h = params.dilation_h; - int dilation_w = params.dilation_w; - - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; - cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, - {oc, kh, kw, ic}, - {pad_h0, 0, pad_w0, 0}, - {stride_h, stride_w}, - {dilation_h, dilation_w}, - {batch, oh, ow, oc}, - mode, - 1); - - typename ImplicitGemm::Arguments arguments{ - problem_size, - {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, - {(cutlass::half_t *)(bias), {0, 0, 0}}, - {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, - {1.f, 1.f}}; - - ImplicitGemm implicit_gemm_op; - size_t bytes = implicit_gemm_op.get_workspace_size(arguments); - - auto ctx = params.ctx; - auto stream = ctx->stream(); - phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc( - ctx->GetPlace(), - bytes, - phi::Stream(reinterpret_cast(stream))); - void *workspace = tmp_gpu_ptrs_data->ptr(); - - cutlass::Status status = implicit_gemm_op.can_implement(arguments); - CUTLASS_CHECK(status); - status = implicit_gemm_op.initialize(arguments, workspace); - CUTLASS_CHECK(status); - status = implicit_gemm_op(stream); - CUTLASS_CHECK(status); - return status; -} - -// config 0 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 1 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 2 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 3 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 4 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); -// config 5 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); -// config 6 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 7 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 8 -template cutlass::Status - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); - -std::vector> - conv2d_bias_relu_all_func = { - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasReluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>}; -std::map, int> map_problem_conv2d_bias_relu; -std::mutex conv2d_bias_relu_mutex; - -void Conv2dBiasRelu(ConvAllParams params) { - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - std::vector problem_size = { - batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; - - if (map_problem_conv2d_bias_relu.count(problem_size)) { - conv2d_bias_relu_all_func[map_problem_conv2d_bias_relu.at(problem_size)]( - params); - return; - } - - int best_config_index = ProfileToGetBestConfig( - conv2d_bias_relu_all_func, params, CONV2D_BIAS_RELU); - - std::lock_guard guard(conv2d_bias_relu_mutex); - map_problem_conv2d_bias_relu[problem_size] = best_config_index; - conv2d_bias_relu_all_func[best_config_index](params); -} -} // namespace cutlass_internal -} // namespace fusion -} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py new file mode 100644 index 0000000000..97dce6116e --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_residual.py @@ -0,0 +1,214 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +sys.path.append("../") +import enum + +from conv2d_common import ( + CommonConvFunction, + CommonCutlassConvKernelDeclare, + CommonCutlassConvKernelExecute, + CommonTail, + GenerateFunctionForPhi, +) +from util import SubstituteTemplate, TileDesc + +# this is a file's header part + +cbr_header = ''' +// Generated by conv2d_bias_residual.py - Do not edit. + +#include +#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { +''' + +# This is a cutlass kernel, will be many these like kernels + +dict_for_declare_part = { + "conv_kind_name": "FpropWithBroadcast", + "epi_part": "cutlass::epilogue::thread::LinearCombinationResidualBlock< ${element_c}, ${element_accum}, ${element_epilogue}, ${element_residul}, ${epilogue_vector_length}, ${act1}, ${binary}, ${act2}>", +} + +cbr_kernel = ( + SubstituteTemplate(CommonCutlassConvKernelDeclare, dict_for_declare_part) + + ''' + const half *residual = params.residual; + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}}, + {(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}}, + {(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f}, + cutlass::conv::SplitKMode::kSerial, + (cutlass::half_t *)(bias), nullptr, + 0, oc}; +''' + + CommonCutlassConvKernelExecute +) + + +class CbrAct(enum.Enum): + Identity = 1 + Relu = 2 + Silu = 3 + + +ActTag = { + CbrAct.Identity: 'cutlass::epilogue::thread::Identity', + CbrAct.Silu: 'cutlass::epilogue::thread::SiLu', + CbrAct.Relu: 'cutlass::epilogue::thread::ReLu', +} + +# Some global variables used, now we only support these residual blocks. +SupportedEpilogue = [ + (CbrAct.Silu, "cutlass::plus", CbrAct.Identity), + (CbrAct.Identity, "cutlass::plus", CbrAct.Relu), +] + +UnderScoreName = { + SupportedEpilogue[0]: "conv2d_bias_silu_add", + SupportedEpilogue[1]: "conv2d_bias_add_relu", +} + +CamelName = { + SupportedEpilogue[0]: "Conv2dBiasSiluAdd", + SupportedEpilogue[1]: "Conv2dBiasAddRelu", +} + +# Generate sm75 TensorOp conv code. +# CUTLASS Tensor Core operations are implemented using CUDA's mma instruction. +# Here is mma.m16n8k8. + + +def generate_sm75_1688(): + kernel_dict = { + "conv_kind_name": "Fprop", + "element_a": "cutlass::half_t", + "layout_a": "cutlass::layout::TensorNHWC", + "element_b": "cutlass::half_t", + "layout_b": "cutlass::layout::TensorNHWC", + "element_c": "cutlass::half_t", + "layout_c": "cutlass::layout::TensorNHWC", + "opcode_class": "cutlass::arch::OpClassTensorOp", + "arch": "cutlass::arch::Sm75", + "stages": "2", + "swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>", + # alpha is always float! + "element_epilogue": "float", + "math_operator": "cutlass::arch::OpMultiplyAdd", + "element_residul": "cutlass::half_t", + } + + kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided" + + # iterate over this loop + element_accums = ["cutlass::half_t", "float"] + iterator_algorithms = [ + "cutlass::conv::IteratorAlgorithm::kOptimized", + # "cutlass::conv::IteratorAlgorithm::kAnalytic", + ] + + math_instructions = [ + ( + "16,8,8", + "cutlass::half_t", + "cutlass::half_t", + "cutlass::half_t", + ), + ( + "16,8,8", + "cutlass::half_t", + "cutlass::half_t", + "float", + ), + ] + + alignments = [8] + + kernel_dict["align_a"] = "8" + kernel_dict["align_b"] = "8" + kernel_dict["epilogue_vector_length"] = "8" + + sm75_code = "" + for epi_res_block in SupportedEpilogue: + op_dict = {} + op_dict["func_name"] = UnderScoreName[epi_res_block].lower() + "_sm75" + op_dict["enum_op_name"] = UnderScoreName[epi_res_block].upper() + # for a op, we record all its kernels into a std::vector in C++ code + all_kernel_names = "" + suffix = 0 + for iterator_algorithm in iterator_algorithms: + for alignment in alignments: + for math_inst in math_instructions: + tiles = [ + TileDesc("64, 64, 64", 2, "32, 32, 64", math_inst), + TileDesc("64, 32, 64", 2, "32, 32, 64", math_inst), + TileDesc("128, 32, 64", 2, "32, 32, 64", math_inst), + TileDesc("128, 64, 64", 2, "32, 32, 64", math_inst), + TileDesc("64, 64, 32", 2, "32, 32, 32", math_inst), + TileDesc("64, 128, 32", 2, "32, 64, 32", math_inst), + # diff is too large, so comment it + # TileDesc("64, 128, 64", 2, "64, 64, 32", math_inst), + TileDesc("64, 256, 32", 2, "64, 64, 32", math_inst), + TileDesc("128, 64, 32", 2, "64, 32, 32", math_inst), + TileDesc("128, 128, 32", 2, "64, 64, 32", math_inst), + TileDesc("128, 256, 32", 2, "64, 64, 32", math_inst), + TileDesc("256, 64, 32", 2, "64, 64, 32", math_inst), + TileDesc("256, 128, 32", 2, "64, 64, 32", math_inst), + ] + for tile in tiles: + kernel_dict["iterator_algorithm"] = iterator_algorithm + kernel_dict["Tshape"] = tile.Tshape + kernel_dict["Wshape"] = tile.Wshape + kernel_dict["Ishape"] = tile.math_inst[0] + kernel_dict["element_accum"] = tile.math_inst[3] + kernel_dict["kernel_func_name"] = op_dict[ + "func_name" + ] + str(suffix) + kernel_dict["act1"] = ActTag[epi_res_block[0]] + kernel_dict["binary"] = epi_res_block[1] + kernel_dict["act2"] = ActTag[epi_res_block[2]] + suffix += 1 + + sm75_code += SubstituteTemplate(cbr_kernel, kernel_dict) + all_kernel_names += ( + kernel_dict["kernel_func_name"] + ", \n" + ) + + # Generate op code with sm_version + op_dict["all_kernel_func_name"] = all_kernel_names + sm75_code += SubstituteTemplate(CommonConvFunction, op_dict) + return sm75_code + + +if __name__ == "__main__": + sm_versions = ["75"] + all_code = cbr_header + all_code += generate_sm75_1688() + all_code += GenerateFunctionForPhi( + sm_versions, SupportedEpilogue, UnderScoreName, CamelName + ) + all_code += CommonTail + with open("generated/conv2d_bias_residual.cu", "w") as f: + f.write(all_code) + f.close() diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu deleted file mode 100644 index b763a40d97..0000000000 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include "cutlass/conv/kernel/default_conv2d_fprop.h" -#include "cutlass/epilogue/thread/linear_combination_silu.h" -#include "paddle/phi/common/memory_utils.h" -#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" - -namespace phi { -namespace fusion { -namespace cutlass_internal { -template -cutlass::Status Conv2dBiasSiluImpl(ConvAllParams params) { - using ElementAccumulator = float; - using ElementComputeEpilogue = float; - using ElementInputA = cutlass::half_t; - using ElementInputB = cutlass::half_t; - using ElementOutput = cutlass::half_t; - using LayoutInputA = cutlass::layout::TensorNHWC; - using LayoutInputB = cutlass::layout::TensorNHWC; - using LayoutOutput = cutlass::layout::TensorNHWC; - using MMAOp = cutlass::arch::OpClassTensorOp; - using SmArch = cutlass::arch::Sm75; - using ThreadblockShape = TShape; - using WarpShape = WShape; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; - using SwizzleThreadBlock = - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; - constexpr int NumStages = 2; - static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = - cutlass::conv::IteratorAlgorithm::kOptimized; - using EpilogueOp = - cutlass::epilogue::thread::LinearCombinationSilu; - - using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< - ElementInputA, - LayoutInputA, - ElementInputB, - LayoutInputB, - ElementOutput, - LayoutOutput, - ElementAccumulator, - MMAOp, - SmArch, - ThreadblockShape, - WarpShape, - InstructionShape, - EpilogueOp, - SwizzleThreadBlock, - NumStages, - cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm, - cutlass::conv::StrideSupport::kStrided, - Alignment, - Alignment>::Kernel; - using ImplicitGemm = - cutlass::conv::device::ImplicitGemmConvolution; - - const half *input = params.input; - const half *weight = params.weight; - const half *bias = params.bias; - half *output = params.output; - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - int oh = params.oh; - int ow = params.ow; - int dilation_h = params.dilation_h; - int dilation_w = params.dilation_w; - - cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; - cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, - {oc, kh, kw, ic}, - {pad_h0, 0, pad_w0, 0}, - {stride_h, stride_w}, - {dilation_h, dilation_w}, - {batch, oh, ow, oc}, - mode, - 1); - - typename ImplicitGemm::Arguments arguments{ - problem_size, - {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, - {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, - {(cutlass::half_t *)(bias), {0, 0, 0}}, - {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, - {1.f, 1.f}}; - - ImplicitGemm implicit_gemm_op; - size_t bytes = implicit_gemm_op.get_workspace_size(arguments); - - auto ctx = params.ctx; - auto stream = ctx->stream(); - phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc( - ctx->GetPlace(), - bytes, - phi::Stream(reinterpret_cast(stream))); - void *workspace = tmp_gpu_ptrs_data->ptr(); - - cutlass::Status status = implicit_gemm_op.can_implement(arguments); - CUTLASS_CHECK(status); - status = implicit_gemm_op.initialize(arguments, workspace); - CUTLASS_CHECK(status); - status = implicit_gemm_op(stream); - CUTLASS_CHECK(status); - return status; -} - -// config 0 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 1 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 2 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 3 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); -// config 4 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); -// config 5 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); -// config 6 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 7 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); -// config 8 -template cutlass::Status - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); - -std::vector> - conv2d_bias_silu_all_func = { - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 64>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 32, 32>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<32, 64, 32>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<64, 64, 32>>, - Conv2dBiasSiluImpl, - cutlass::gemm::GemmShape<64, 32, 32>>}; - -std::map, int> map_problem_conv2d_bias_silu; -std::mutex conv2d_bias_silu_mutex; - -void Conv2dBiasSilu(ConvAllParams params) { - int batch = params.batch; - int ic = params.ic; - int ih = params.ih; - int iw = params.iw; - int kh = params.kh; - int kw = params.kw; - int oc = params.oc; - int pad_h0 = params.pad_h0; - int pad_w0 = params.pad_w0; - int stride_h = params.stride_h; - int stride_w = params.stride_w; - - std::vector problem_size = { - batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; - - if (map_problem_conv2d_bias_silu.count(problem_size)) { - conv2d_bias_silu_all_func[map_problem_conv2d_bias_silu.at(problem_size)]( - params); - return; - } - - int best_config_index = ProfileToGetBestConfig( - conv2d_bias_silu_all_func, params, CONV2D_BIAS_SILU); - - std::lock_guard guard(conv2d_bias_silu_mutex); - - map_problem_conv2d_bias_silu[problem_size] = best_config_index; - conv2d_bias_silu_all_func[best_config_index](params); -} -} // namespace cutlass_internal -} // namespace fusion -} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py new file mode 100644 index 0000000000..c99cbe5a11 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_common.py @@ -0,0 +1,204 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys + +sys.path.append("../") +from util import SubstituteTemplate + +# For beginners, these template parameters may be difficult to understand. +# Please refer to the conv-related demo of CUTLASS for better understanding. +# https://github.com/NVIDIA/cutlass/tree/master/examples + +CommonCutlassConvKernelDeclare = """ +cutlass::Status ${kernel_func_name}(const ConvAllParams& params) { + using kernel_base = + typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}< + ${element_a}, + ${layout_a}, + ${element_b}, + ${layout_b}, + ${element_c}, + ${layout_c}, + ${element_accum}, + ${opcode_class}, + ${arch}, + cutlass::gemm::GemmShape<${Tshape}>, + cutlass::gemm::GemmShape<${Wshape}>, + cutlass::gemm::GemmShape<${Ishape}>, + ${epi_part}, + ${swizzling_functor}, + ${stages}, + ${math_operator}, + ${iterator_algorithm}, + ${stride_support}, + ${align_a}, + ${align_b} + >::Kernel; + + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + int groups = params.groups; + int kc = ic / groups; + + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, + {oc, kh, kw, ic / groups}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + cutlass::conv::Mode::kCrossCorrelation, + 1, + groups); +""" + +# This is the execution part of this cutlass conv kernel. + +CommonCutlassConvKernelExecute = """ + ImplicitGemm implicit_gemm_op; + size_t bytes = implicit_gemm_op.get_workspace_size(arguments); + + auto ctx = params.ctx; + auto stream = ctx->stream(); + phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = + phi::memory_utils::Alloc( + ctx->GetPlace(), + bytes, + phi::Stream(reinterpret_cast(stream))); + void *workspace = tmp_gpu_ptrs_data->ptr(); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = implicit_gemm_op.initialize(arguments, workspace); + CUTLASS_CHECK(status); + status = implicit_gemm_op(stream); + CUTLASS_CHECK(status); + return status; +} +""" + +# CommonConvFunction is a wrapper for many kernels +# a func_name is like conv2d_bias_silu_sm75 +# it has many kernels, we should pick up a performence-best +# ${func_name} is like conv2d_bias_silu_sm75 +# ${enum_op_name} is like CONV2D_BIAS_SILU + +CommonConvFunction = """ +std::vector> + ${func_name}_all_func = {${all_kernel_func_name}}; + +std::map, int> map_problem_${func_name}; +std::mutex ${func_name}_mutex; + +void ${func_name}(const ConvAllParams& params) { + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + //int pad_h0 = params.pad_h0; + //int pad_w0 = params.pad_w0; + int groups = params.groups; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + std::vector problem_size = { + batch, ic, ih, iw, kh, kw, oc, groups, stride_h, stride_w}; + + if (map_problem_${func_name}.count(problem_size)) { + ${func_name}_all_func[map_problem_${func_name}.at(problem_size)]( + params); + return; + } + + int best_config_index = ProfileToGetBestConfig( + ${func_name}_all_func, params, ${enum_op_name}); + + std::lock_guard guard(${func_name}_mutex); + + map_problem_${func_name}[problem_size] = best_config_index; + ${func_name}_all_func[best_config_index](params); +} +""" + + +# We should wrapper all op_name_with_sm_version into a function +# like : wrapper conv2d_bias_silu_sm75, conv2d_bias_silu_sm80, conv2d_bias_silu_sm86 into conv2d_bias_silu for phi kernel +# this function is invoked by phi kernel + +CommonWrapperForPhi = """ +void ${op_name}(const ConvAllParams& params) { + ${dispatch_body} +} +""" + + +CommonDispatchTemp = ''' + if (params.sm_version == ${sm_code}) + { + ${op_name_with_sm}(params); + } + ''' + + +# this is a file's ending part + +CommonTail = ''' +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi +''' + + +# wrap different sm versions into a function +def GenerateFunctionForPhi( + sm_versions, support_epi_funcs, underscore_names, camel_names +): + generated_code = "" + for epi_func in support_epi_funcs: + dispatch_body = "" + for sm_version in sm_versions: + sm_dicts = {} + sm_dicts["sm_code"] = sm_version + sm_dicts["op_name_with_sm"] = ( + underscore_names[epi_func].lower() + "_sm" + sm_version + ) + dispatch_body += SubstituteTemplate(CommonDispatchTemp, sm_dicts) + op_dicts = {} + op_dicts["dispatch_body"] = dispatch_body + op_dicts["op_name"] = camel_names[epi_func] + generated_code += SubstituteTemplate(CommonWrapperForPhi, op_dicts) + return generated_code diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h index b740d49fc1..0068b12e7c 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h @@ -46,16 +46,18 @@ typedef struct { int dilation_w; int oh; int ow; + int groups; const phi::GPUContext *ctx; float alpha; // for leaky_relu use + int sm_version = 75; } ConvAllParams; // Below functions are provided by cutlass, they are called by phi. -void Conv2dBiasAddRelu(ConvAllParams params); -void Conv2dBiasRelu(ConvAllParams params); -void Conv2dBiasLeakyRelu(ConvAllParams params); -void Conv2dBiasSilu(ConvAllParams params); -void Conv2dBias(ConvAllParams params); +void Conv2dBiasAddRelu(const ConvAllParams ¶ms); +void Conv2dBiasRelu(const ConvAllParams ¶ms); +void Conv2dBiasLeakyRelu(const ConvAllParams ¶ms); +void Conv2dBiasSilu(const ConvAllParams ¶ms); +void Conv2dBias(const ConvAllParams ¶ms); } // namespace cutlass_internal } // namespace fusion } // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu index 174cb4aaa4..63d9ffae3a 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu @@ -62,12 +62,14 @@ __global__ void naive_conv2d_kernel(const half *input, int dilation_w, int oh, int ow, + int groups, const half *residual, float alpha, // for leaky_relu OpType op_type) { int M = batch * oh * ow; int N = oc; - int K = ic * kh * kw; + int kc = ic / groups; + int K = kc * kh * kw; int m_i = threadIdx.x + blockIdx.x * blockDim.x; int n_i = threadIdx.y + blockIdx.y * blockDim.y; if (m_i >= M || n_i >= N) return; @@ -76,19 +78,20 @@ __global__ void naive_conv2d_kernel(const half *input, int oh_i = (m_i % (oh * ow)) / ow; int ow_i = (m_i % (oh * ow)) % ow; int oc_i = n_i; + int groups_i = (oc_i / (oc / groups)); - struct logical_coord weight_shape = {oc, ic, kh, kw}; + struct logical_coord weight_shape = {oc, kc, kh, kw}; struct logical_coord input_shape = {batch, ic, ih, iw}; int out_offset = m_i * N + n_i; float *out_ptr = output + out_offset; float sum = 0.f; for (int k_i = 0; k_i < K; k_i++) { - int ic_i = k_i / (kh * kw); + int ic_i = k_i / (kh * kw) + groups_i * kc; int kh_i = (k_i % (kh * kw)) / kw; int kw_i = (k_i % (kh * kw)) % kw; - struct logical_coord weight_index = {oc_i, ic_i, kh_i, kw_i}; + struct logical_coord weight_index = {oc_i, k_i / (kh * kw), kh_i, kw_i}; int ih_i = oh_i * stride_h - pad_h + kh_i * dilation_h; int iw_i = ow_i * stride_w - pad_w + kw_i * dilation_w; @@ -127,7 +130,7 @@ __global__ void naive_conv2d_kernel(const half *input, } } -float conv2d_diff_gpu(ConvAllParams params, OpType op_type) { +float conv2d_diff_gpu(const ConvAllParams ¶ms, OpType op_type) { const half *input = params.input; const half *weight = params.weight; const half *bias = params.bias; @@ -146,6 +149,7 @@ float conv2d_diff_gpu(ConvAllParams params, OpType op_type) { int dilation_h = params.dilation_h; int dilation_w = params.dilation_w; const half *residual = params.residual; + int groups = params.groups; int oh = params.oh; int ow = params.ow; @@ -186,6 +190,7 @@ float conv2d_diff_gpu(ConvAllParams params, OpType op_type) { dilation_w, oh, ow, + groups, residual, params.alpha, op_type); @@ -212,7 +217,7 @@ std::string OpType2String(OpType op_type) { return "conv2d_bias_relu"; break; case CONV2D_BIAS_SILU: - return "conv2d_bias_add_silu"; + return "conv2d_bias_silu"; break; case CONV2D_BIAS_ADD_RELU: return "conv2d_bias_add_relu"; @@ -227,7 +232,7 @@ std::string OpType2String(OpType op_type) { int ProfileToGetBestConfig( const std::vector> &all_func, - ConvAllParams params, + const ConvAllParams ¶ms, OpType op_type) { constexpr int WARMUP = 10; constexpr int REPEAT = 100; @@ -258,10 +263,11 @@ int ProfileToGetBestConfig( if (elapsed_time < min_time && status == cutlass::Status::kSuccess) { min_time = elapsed_time; min_time_index = i; + // debug code + VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff " + << conv2d_diff_gpu(params, op_type) << " compared with baseline," + << "cost_time: " << elapsed_time << "ms."; } - // debug code - VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff " - << conv2d_diff_gpu(params, op_type) << " compared with baseline."; } if (min_time_index < 0) { diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h index 59bee67964..710de13a51 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h @@ -23,6 +23,7 @@ #include "cutlass/conv/device/implicit_gemm_convolution.h" #include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/common/memory_utils.h" #include "paddle/phi/core/enforce.h" namespace phi { @@ -40,16 +41,17 @@ typedef enum { CONV2D_BIAS_RELU, CONV2D_BIAS_ADD_RELU, CONV2D_BIAS_SILU, - CONV2D_BIAS_LEAKY_RELU + CONV2D_BIAS_LEAKY_RELU, + CONV2D_BIAS_SILU_ADD } OpType; // conv2d_diff_gpu calculate diff of cutlass output and baseline output, you can // use them to debug. return value is the max diff between cutlass and baseline. -float conv2d_diff_gpu(ConvAllParams params, OpType op_type); +float conv2d_diff_gpu(const ConvAllParams& params, OpType op_type); int ProfileToGetBestConfig( const std::vector>& all_func, - ConvAllParams params, + const ConvAllParams& params, OpType op_type); } // namespace cutlass_internal diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu b/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu index aea3c62494..ff83e5db4d 100644 --- a/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu +++ b/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu @@ -43,12 +43,14 @@ void Conv2dFusionKernel(const Context& ctx, CHECK_EQ(filter_dims.size() == 4UL, true); CHECK_EQ(strides.size() == 2UL, true); CHECK_EQ(dilations.size() == 2UL, true); - CHECK_EQ(groups == 1, true); + CHECK_EQ(padding_algorithm == "EXPLICIT", true); const int batch = in_dims[0]; const int ic = in_dims[3]; const int ih = in_dims[1]; const int iw = in_dims[2]; + CHECK_EQ(groups == 1, true); + CHECK_EQ(ic == groups * filter_dims[3], true); int pad_h0 = 0; int pad_h1 = 0; int pad_w0 = 0; @@ -104,6 +106,7 @@ void Conv2dFusionKernel(const Context& ctx, dilation_w, oh, ow, + groups, &ctx}; if (residual) { diff --git a/paddle/phi/kernels/fusion/cutlass/util.py b/paddle/phi/kernels/fusion/cutlass/util.py new file mode 100644 index 0000000000..200960f39c --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/util.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re + + +class TileDesc: + def __init__(self, Tshape, stages, Wshape, math_inst): + self.Tshape = Tshape + self.stages = stages + self.Wshape = Wshape + self.math_inst = math_inst + + +def SubstituteTemplate(template, values): + text = template + changed = True + while changed: + changed = False + for key, value in values.items(): + regex = "\\$\\{%s\\}" % key + newtext = re.sub(regex, value, text) + if newtext != text: + changed = True + text = newtext + return text diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py index 8adeff0f73..2288b76f8b 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py @@ -27,7 +27,7 @@ import paddle.inference as paddle_infer class TestCutlassConv2dFusionOp1(CutlassAutoScanTest): def sample_program_configs(self, *args, **kwargs): def generate_input1(input_shape): - return np.random.random(input_shape).astype(np.float32) + return (np.random.random(input_shape) - 0.5).astype(np.float32) def generate_weight(weight_shape): return np.random.random(weight_shape).astype(np.float32) @@ -105,7 +105,9 @@ class TestCutlassConv2dFusionOp1(CutlassAutoScanTest): "op_type": act, "op_inputs": {"X": ["output_data0"]}, "op_outputs": {"Out": ["output_data1"]}, - "op_attrs": {}, + "op_attrs": { + "alpha": 2.0, + }, }, ] -- GitLab