Unverified commit 4e9e23cb, authored by zhoutianzi666 and committed by GitHub

[Paddle Inference] Use python to generate cutlass code (#50603)

* use python to generate cutlass code

* refine CommonConvKernelPart1, CommonConvKernelPart2

* remove useless code in generate_cutlass_code.sh

* add more config in conv2d_residual

* CommonCutlassConvKernelPart1 and CommonCutlassConvKernelPart2

* add group conv support in util.cu

* remove .sh

* refine name

* make name good

* add fuse_alpha

* make code easy to understand

* do not fopen generated files in py

* use python script to generate conv2d (group=1) cutlass code

* use const &

* use const & and use python script to generate conv2d (group=1) code
Parent 39899d79
@@ -96,4 +96,5 @@ paddle/fluid/prim/api/generated/prim_api/*
 paddle/fluid/framework/__init__.py
 paddle/phi/api/profiler/__init__.py
 python/paddle/incubate/fleet/parameter_server/pslib/ps_pb2.py
+paddle/phi/kernels/fusion/cutlass/conv2d/generated/*
 python/paddle/fluid/incubate/fleet/parameter_server/pslib/ps_pb2.py
@@ -207,6 +207,10 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     op_desc->SetAttr("data_format", std::string{"NHWC"});
     if (cutlass_enable && CutlassIsValid(op_node)) {
       op_desc->SetType("conv2d_fusion_cutlass");
+      // conv2d_fusion_cutlass must have this attribute because of its signature.
+      if (!op_desc->HasAttr("fuse_alpha")) {
+        op_desc->SetAttr("fuse_alpha", 0.f);
+      }
     }
     op_desc->Flush();
@@ -114,7 +114,15 @@ file(
   "fusion/gpu/*.cu")
 if(WITH_CUTLASS)
-  file(GLOB cutlass_cu "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu")
+  execute_process(
+    COMMAND ${CMAKE_COMMAND} -E make_directory
+            "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d/generated"
+    COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_act.py"
+    COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_residual.py"
+    WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d")
+  file(GLOB cutlass_cu "fusion/cutlass/conv2d/generated/*.cu"
+       "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu")
   list(APPEND kernel_cu ${cutlass_cu})
 endif()
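# Note: the two generator scripts above run at CMake configure time and write
# .cu files into fusion/cutlass/conv2d/generated (which this commit adds to
# .gitignore); they can also be run by hand from that WORKING_DIRECTORY.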
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombination<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(
ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(
ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(
ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(
ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(
ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_all_func = {
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias;
std::mutex conv2d_bias_mutex;
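// Conv2dBias profiles every kernel in conv2d_bias_all_func once per unique
// problem size, caches the index of the fastest one in
// map_problem_conv2d_bias, and dispatches straight to it on later calls.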
void Conv2dBias(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias.count(problem_size)) {
conv2d_bias_all_func[map_problem_conv2d_bias.at(problem_size)](params);
return;
}
int best_config_index =
ProfileToGetBestConfig(conv2d_bias_all_func, params, CONV2D_BIAS);
std::lock_guard<std::mutex> guard(conv2d_bias_mutex);
map_problem_conv2d_bias[problem_size] = best_config_index;
conv2d_bias_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import enum
from conv2d_common import (
CommonConvFunction,
CommonCutlassConvKernelDeclare,
CommonCutlassConvKernelExecute,
CommonTail,
GenerateFunctionForPhi,
)
from util import SubstituteTemplate, TileDesc
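# SubstituteTemplate and TileDesc live in util.py and are not part of this
# diff; the sketch below is an assumption inferred from how they are used in
# this script, not the real implementation.
#
#   def SubstituteTemplate(template, values):
#       # Replace every ${key} placeholder in template with values[key].
#       for key, value in values.items():
#           template = template.replace("${%s}" % key, value)
#       return template
#
#   class TileDesc:
#       def __init__(self, Tshape, stages, Wshape, math_inst):
#           self.Tshape = Tshape        # threadblock tile, e.g. "64, 64, 64"
#           self.stages = stages        # pipeline stages
#           self.Wshape = Wshape        # warp tile, e.g. "32, 32, 64"
#           self.math_inst = math_inst  # (Ishape, element_a, element_b, accum)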
# Header of the generated file.
cba_header = '''
// Generated by conv2d_bias_act.py - Do not edit.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
'''
# This is a CUTLASS kernel template; many such kernels will be generated from it.
dict_for_declare_part = {
"conv_kind_name": "Fprop",
"epi_part": "${epi_func}< ${element_c}, ${epilogue_vector_length}, ${element_accum}, ${element_epilogue}>",
}
cba_kernel_no_alpha = (
SubstituteTemplate(CommonCutlassConvKernelDeclare, dict_for_declare_part)
+ '''
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
'''
+ CommonCutlassConvKernelExecute
)
# This is used for leaky_relu, which needs a fuse_alpha parameter.
cba_kernel_alpha = cba_kernel_no_alpha.replace(
"{1.f, 1.f}", "{1.f, 1.f, alpha}"
).replace(
"typename ImplicitG", "float alpha = params.alpha; typename ImplicitG"
)
class CbaAct(enum.Enum):
Identity = 1
Relu = 2
Silu = 3
LeakyRelu = 4
# Global variables; for now only these activations are supported.
SupportedAct = [
CbaAct.Identity,
CbaAct.Relu,
CbaAct.Silu,
CbaAct.LeakyRelu,
]
ActTag = {
SupportedAct[0]: 'cutlass::epilogue::thread::LinearCombination',
SupportedAct[1]: 'cutlass::epilogue::thread::LinearCombinationRelu',
SupportedAct[2]: 'cutlass::epilogue::thread::LinearCombinationSilu',
SupportedAct[3]: 'cutlass::epilogue::thread::LinearCombinationLeakyRelu',
}
UnderScoreName = {
SupportedAct[0]: "conv2d_bias",
SupportedAct[1]: "conv2d_bias_relu",
SupportedAct[2]: "conv2d_bias_silu",
SupportedAct[3]: "conv2d_bias_leaky_relu",
}
CamelName = {
SupportedAct[0]: "Conv2dBias",
SupportedAct[1]: "Conv2dBiasRelu",
SupportedAct[2]: "Conv2dBiasSilu",
SupportedAct[3]: "Conv2dBiasLeakyRelu",
}
# Generate sm75 TensorOp conv code.
# CUTLASS Tensor Core operations are implemented using CUDA's mma instruction.
# Here is mma.m16n8k8.
def generate_sm75_1688():
kernel_dict = {
"element_a": "cutlass::half_t",
"layout_a": "cutlass::layout::TensorNHWC",
"element_b": "cutlass::half_t",
"layout_b": "cutlass::layout::TensorNHWC",
"element_c": "cutlass::half_t",
"layout_c": "cutlass::layout::TensorNHWC",
"opcode_class": "cutlass::arch::OpClassTensorOp",
"arch": "cutlass::arch::Sm75",
"stages": "2",
"swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
# alpha is always float!
"element_epilogue": "float",
"math_operator": "cutlass::arch::OpMultiplyAdd",
}
kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided"
# Configuration lists iterated over below.
element_accums = ["cutlass::half_t", "float"]
iterator_algorithms = [
"cutlass::conv::IteratorAlgorithm::kOptimized",
# "cutlass::conv::IteratorAlgorithm::kAnalytic",
]
math_instructions = [
(
"16,8,8",
"cutlass::half_t",
"cutlass::half_t",
"cutlass::half_t",
),
(
"16,8,8",
"cutlass::half_t",
"cutlass::half_t",
"float",
),
]
alignments = [8]
kernel_dict["align_a"] = "8"
kernel_dict["align_b"] = "8"
# oc must be divisible by epilogue_vector_length.
kernel_dict["epilogue_vector_length"] = "8"
sm75_code = ""
for epi_func in SupportedAct:
op_dict = {}
op_dict["func_name"] = UnderScoreName[epi_func].lower() + "_sm75"
op_dict["enum_op_name"] = UnderScoreName[epi_func].upper()
# For each function, record all of its kernels in a std::vector in the generated C++ code.
all_kernel_names = ""
kernel_dict["epi_func"] = ActTag[epi_func]
suffix = 0
for iterator_algorithm in iterator_algorithms:
for alignment in alignments:
for math_inst in math_instructions:
tiles = [
TileDesc("64, 64, 64", 2, "32, 32, 64", math_inst),
TileDesc("64, 32, 64", 2, "32, 32, 64", math_inst),
TileDesc("128, 32, 64", 2, "32, 32, 64", math_inst),
TileDesc("128, 64, 64", 2, "32, 32, 64", math_inst),
TileDesc("64, 64, 32", 2, "32, 32, 32", math_inst),
TileDesc("64, 128, 32", 2, "32, 64, 32", math_inst),
TileDesc("64, 128, 64", 2, "64, 64, 32", math_inst),
TileDesc("64, 256, 32", 2, "64, 64, 32", math_inst),
TileDesc("128, 64, 32", 2, "64, 32, 32", math_inst),
]
for tile in tiles:
kernel_dict["iterator_algorithm"] = iterator_algorithm
kernel_dict["Tshape"] = tile.Tshape
kernel_dict["Wshape"] = tile.Wshape
kernel_dict["Ishape"] = tile.math_inst[0]
kernel_dict["element_accum"] = tile.math_inst[3]
kernel_dict["kernel_func_name"] = op_dict[
"func_name"
] + str(suffix)
suffix += 1
cba_kernel = cba_kernel_no_alpha
if epi_func in [CbaAct.LeakyRelu]:
cba_kernel = cba_kernel_alpha
sm75_code += SubstituteTemplate(cba_kernel, kernel_dict)
all_kernel_names += (
kernel_dict["kernel_func_name"] + ", \n"
)
# Generate op code
op_dict["all_kernel_func_name"] = all_kernel_names
sm75_code += SubstituteTemplate(CommonConvFunction, op_dict)
return sm75_code
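# Kernel names are ${func_name} plus a running suffix (conv2d_bias_sm750,
# conv2d_bias_sm751, ...); CommonConvFunction then collects all of them into
# one std::vector for profiling.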
if __name__ == "__main__":
sm_versions = ["75"]
all_code = cba_header
all_code += generate_sm75_1688()
all_code += GenerateFunctionForPhi(
sm_versions, SupportedAct, UnderScoreName, CamelName
)
all_code += CommonTail
with open("generated/conv2d_bias_act.cu", "w") as f:
f.write(all_code)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasAddReluImpl(ConvAllParams params) {
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
cutlass::half_t,
float,
float,
cutlass::half_t,
Alignment,
cutlass::epilogue::thread::Identity,
cutlass::plus,
cutlass::epilogue::thread::ReLu>;
using Conv2dFpropKernel =
typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
TShape,
WShape,
cutlass::gemm::GemmShape<16, 8, 8>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
2,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kOptimized,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
const half *residual = params.residual;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Conv2dProblemSize problem_size(
{batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
cutlass::conv::Mode::kCrossCorrelation,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)weight, {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}},
{(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f},
cutlass::conv::SplitKMode::kSerial,
(cutlass::half_t *)(bias),
nullptr,
0,
oc};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
// config 9
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 10
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 11
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 12
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_add_relu_all_func = {
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_add_relu;
std::mutex conv2d_bias_add_relu_mutex;
void Conv2dBiasAddRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_add_relu.count(problem_size)) {
conv2d_bias_add_relu_all_func[map_problem_conv2d_bias_add_relu.at(
problem_size)](params);
return;
}
std::lock_guard<std::mutex> guard(conv2d_bias_add_relu_mutex);
// Config 6 produces too large a numerical diff, so it is disabled.
conv2d_bias_add_relu_all_func[6] = nullptr;
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_add_relu_all_func, params, CONV2D_BIAS_ADD_RELU);
map_problem_conv2d_bias_add_relu[problem_size] = best_config_index;
conv2d_bias_add_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasLeakyReluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationLeakyRelu<
ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
float alpha = params.alpha;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f, alpha}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_leaky_relu_all_func = {
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_leaky_relu;
std::mutex conv2d_bias_leaky_relu_mutex;
void Conv2dBiasLeakyRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_leaky_relu.count(problem_size)) {
conv2d_bias_leaky_relu_all_func[map_problem_conv2d_bias_leaky_relu.at(
problem_size)](params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_leaky_relu_all_func, params, CONV2D_BIAS_LEAKY_RELU);
std::lock_guard<std::mutex> guard(conv2d_bias_leaky_relu_mutex);
map_problem_conv2d_bias_leaky_relu[problem_size] = best_config_index;
conv2d_bias_leaky_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasReluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationRelu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_relu_all_func = {
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_relu;
std::mutex conv2d_bias_relu_mutex;
void Conv2dBiasRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_relu.count(problem_size)) {
conv2d_bias_relu_all_func[map_problem_conv2d_bias_relu.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_relu_all_func, params, CONV2D_BIAS_RELU);
std::lock_guard<std::mutex> guard(conv2d_bias_relu_mutex);
map_problem_conv2d_bias_relu[problem_size] = best_config_index;
conv2d_bias_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import enum
from conv2d_common import (
CommonConvFunction,
CommonCutlassConvKernelDeclare,
CommonCutlassConvKernelExecute,
CommonTail,
GenerateFunctionForPhi,
)
from util import SubstituteTemplate, TileDesc
# Header of the generated file.
cbr_header = '''
// Generated by conv2d_bias_residual.py - Do not edit.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
'''
# This is a CUTLASS kernel template; many such kernels will be generated from it.
dict_for_declare_part = {
"conv_kind_name": "FpropWithBroadcast",
"epi_part": "cutlass::epilogue::thread::LinearCombinationResidualBlock< ${element_c}, ${element_accum}, ${element_epilogue}, ${element_residul}, ${epilogue_vector_length}, ${act1}, ${binary}, ${act2}>",
}
cbr_kernel = (
SubstituteTemplate(CommonCutlassConvKernelDeclare, dict_for_declare_part)
+ '''
const half *residual = params.residual;
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {kc, kc * kw, kc * kw * kh}},
{(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}},
{(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f},
cutlass::conv::SplitKMode::kSerial,
(cutlass::half_t *)(bias), nullptr,
0, oc};
'''
+ CommonCutlassConvKernelExecute
)
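# Unlike the plain bias kernels, the residual kernel binds the residual tensor
# as operand C and passes the bias pointer separately after the SplitKMode
# argument (matching the hand-written conv2d_bias_add_relu.cu above).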
class CbrAct(enum.Enum):
Identity = 1
Relu = 2
Silu = 3
ActTag = {
CbrAct.Identity: 'cutlass::epilogue::thread::Identity',
CbrAct.Silu: 'cutlass::epilogue::thread::SiLu',
CbrAct.Relu: 'cutlass::epilogue::thread::ReLu',
}
# Global variables; for now only these residual blocks are supported.
SupportedEpilogue = [
(CbrAct.Silu, "cutlass::plus", CbrAct.Identity),
(CbrAct.Identity, "cutlass::plus", CbrAct.Relu),
]
UnderScoreName = {
SupportedEpilogue[0]: "conv2d_bias_silu_add",
SupportedEpilogue[1]: "conv2d_bias_add_relu",
}
CamelName = {
SupportedEpilogue[0]: "Conv2dBiasSiluAdd",
SupportedEpilogue[1]: "Conv2dBiasAddRelu",
}
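# For example, SupportedEpilogue[1] makes ${epi_part} expand (with float
# accumulation) to:
#   cutlass::epilogue::thread::LinearCombinationResidualBlock<
#       cutlass::half_t, float, float, cutlass::half_t, 8,
#       cutlass::epilogue::thread::Identity, cutlass::plus,
#       cutlass::epilogue::thread::ReLu>
# which matches the hand-written EpilogueOp in conv2d_bias_add_relu.cu above.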
# Generate sm75 TensorOp conv code.
# CUTLASS Tensor Core operations are implemented using CUDA's mma instruction.
# Here is mma.m16n8k8.
def generate_sm75_1688():
kernel_dict = {
"conv_kind_name": "Fprop",
"element_a": "cutlass::half_t",
"layout_a": "cutlass::layout::TensorNHWC",
"element_b": "cutlass::half_t",
"layout_b": "cutlass::layout::TensorNHWC",
"element_c": "cutlass::half_t",
"layout_c": "cutlass::layout::TensorNHWC",
"opcode_class": "cutlass::arch::OpClassTensorOp",
"arch": "cutlass::arch::Sm75",
"stages": "2",
"swizzling_functor": "cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>",
# alpha is always float!
"element_epilogue": "float",
"math_operator": "cutlass::arch::OpMultiplyAdd",
"element_residul": "cutlass::half_t",
}
kernel_dict["stride_support"] = "cutlass::conv::StrideSupport::kStrided"
# Configuration lists iterated over below.
element_accums = ["cutlass::half_t", "float"]
iterator_algorithms = [
"cutlass::conv::IteratorAlgorithm::kOptimized",
# "cutlass::conv::IteratorAlgorithm::kAnalytic",
]
math_instructions = [
(
"16,8,8",
"cutlass::half_t",
"cutlass::half_t",
"cutlass::half_t",
),
(
"16,8,8",
"cutlass::half_t",
"cutlass::half_t",
"float",
),
]
alignments = [8]
kernel_dict["align_a"] = "8"
kernel_dict["align_b"] = "8"
kernel_dict["epilogue_vector_length"] = "8"
sm75_code = ""
for epi_res_block in SupportedEpilogue:
op_dict = {}
op_dict["func_name"] = UnderScoreName[epi_res_block].lower() + "_sm75"
op_dict["enum_op_name"] = UnderScoreName[epi_res_block].upper()
# For each op, record all of its kernels in a std::vector in the generated C++ code.
all_kernel_names = ""
suffix = 0
for iterator_algorithm in iterator_algorithms:
for alignment in alignments:
for math_inst in math_instructions:
tiles = [
TileDesc("64, 64, 64", 2, "32, 32, 64", math_inst),
TileDesc("64, 32, 64", 2, "32, 32, 64", math_inst),
TileDesc("128, 32, 64", 2, "32, 32, 64", math_inst),
TileDesc("128, 64, 64", 2, "32, 32, 64", math_inst),
TileDesc("64, 64, 32", 2, "32, 32, 32", math_inst),
TileDesc("64, 128, 32", 2, "32, 64, 32", math_inst),
# The numerical diff for this tile is too large, so it is commented out.
# TileDesc("64, 128, 64", 2, "64, 64, 32", math_inst),
TileDesc("64, 256, 32", 2, "64, 64, 32", math_inst),
TileDesc("128, 64, 32", 2, "64, 32, 32", math_inst),
TileDesc("128, 128, 32", 2, "64, 64, 32", math_inst),
TileDesc("128, 256, 32", 2, "64, 64, 32", math_inst),
TileDesc("256, 64, 32", 2, "64, 64, 32", math_inst),
TileDesc("256, 128, 32", 2, "64, 64, 32", math_inst),
]
for tile in tiles:
kernel_dict["iterator_algorithm"] = iterator_algorithm
kernel_dict["Tshape"] = tile.Tshape
kernel_dict["Wshape"] = tile.Wshape
kernel_dict["Ishape"] = tile.math_inst[0]
kernel_dict["element_accum"] = tile.math_inst[3]
kernel_dict["kernel_func_name"] = op_dict[
"func_name"
] + str(suffix)
kernel_dict["act1"] = ActTag[epi_res_block[0]]
kernel_dict["binary"] = epi_res_block[1]
kernel_dict["act2"] = ActTag[epi_res_block[2]]
suffix += 1
sm75_code += SubstituteTemplate(cbr_kernel, kernel_dict)
all_kernel_names += (
kernel_dict["kernel_func_name"] + ", \n"
)
# Generate op code with sm_version
op_dict["all_kernel_func_name"] = all_kernel_names
sm75_code += SubstituteTemplate(CommonConvFunction, op_dict)
return sm75_code
if __name__ == "__main__":
sm_versions = ["75"]
all_code = cbr_header
all_code += generate_sm75_1688()
all_code += GenerateFunctionForPhi(
sm_versions, SupportedEpilogue, UnderScoreName, CamelName
)
all_code += CommonTail
with open("generated/conv2d_bias_residual.cu", "w") as f:
f.write(all_code)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasSiluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationSilu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data = phi::memory_utils::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_silu_all_func = {
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_silu;
std::mutex conv2d_bias_silu_mutex;
void Conv2dBiasSilu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_silu.count(problem_size)) {
conv2d_bias_silu_all_func[map_problem_conv2d_bias_silu.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_silu_all_func, params, CONV2D_BIAS_SILU);
std::lock_guard<std::mutex> guard(conv2d_bias_silu_mutex);
map_problem_conv2d_bias_silu[problem_size] = best_config_index;
conv2d_bias_silu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
from util import SubstituteTemplate
# For beginners, these template parameters may be difficult to understand.
# Please refer to the conv-related demo of CUTLASS for better understanding.
# https://github.com/NVIDIA/cutlass/tree/master/examples
CommonCutlassConvKernelDeclare = """
cutlass::Status ${kernel_func_name}(const ConvAllParams& params) {
using kernel_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
${layout_b},
${element_c},
${layout_c},
${element_accum},
${opcode_class},
${arch},
cutlass::gemm::GemmShape<${Tshape}>,
cutlass::gemm::GemmShape<${Wshape}>,
cutlass::gemm::GemmShape<${Ishape}>,
${epi_part},
${swizzling_functor},
${stages},
${math_operator},
${iterator_algorithm},
${stride_support},
${align_a},
${align_b}
>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<kernel_base>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int groups = params.groups;
int kc = ic / groups;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic / groups},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
cutlass::conv::Mode::kCrossCorrelation,
1,
groups);
"""
# This is the execution part of this cutlass conv kernel.
CommonCutlassConvKernelExecute = """
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
phi::Allocator::AllocationPtr tmp_gpu_ptrs_data =
phi::memory_utils::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
"""
# CommonConvFunction is a wrapper over many kernels for one op.
# ${func_name} is like conv2d_bias_silu_sm75; it wraps many kernels, from
# which we pick the best-performing one.
# ${enum_op_name} is like CONV2D_BIAS_SILU.
CommonConvFunction = """
std::vector<std::function<cutlass::Status(const ConvAllParams)>>
${func_name}_all_func = {${all_kernel_func_name}};
std::map<std::vector<int>, int> map_problem_${func_name};
std::mutex ${func_name}_mutex;
void ${func_name}(const ConvAllParams& params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
//int pad_h0 = params.pad_h0;
//int pad_w0 = params.pad_w0;
int groups = params.groups;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, groups, stride_h, stride_w};
if (map_problem_${func_name}.count(problem_size)) {
${func_name}_all_func[map_problem_${func_name}.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
${func_name}_all_func, params, ${enum_op_name});
std::lock_guard<std::mutex> guard(${func_name}_mutex);
map_problem_${func_name}[problem_size] = best_config_index;
${func_name}_all_func[best_config_index](params);
}
"""
# Wrap every op_name_with_sm_version variant into one function for the phi
# kernel: e.g., conv2d_bias_silu_sm75, conv2d_bias_silu_sm80, and
# conv2d_bias_silu_sm86 are wrapped into conv2d_bias_silu.
# That wrapper function is invoked by the phi kernel.
CommonWrapperForPhi = """
void ${op_name}(const ConvAllParams& params) {
${dispatch_body}
}
"""
CommonDispatchTemp = '''
if (params.sm_version == ${sm_code})
{
${op_name_with_sm}(params);
}
'''
# This is the closing part of each generated file.
CommonTail = '''
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
'''
# Wrap the functions for different SM versions into one dispatch function.
def GenerateFunctionForPhi(
sm_versions, support_epi_funcs, underscore_names, camel_names
):
generated_code = ""
for epi_func in support_epi_funcs:
dispatch_body = ""
for sm_version in sm_versions:
sm_dicts = {}
sm_dicts["sm_code"] = sm_version
sm_dicts["op_name_with_sm"] = (
underscore_names[epi_func].lower() + "_sm" + sm_version
)
dispatch_body += SubstituteTemplate(CommonDispatchTemp, sm_dicts)
op_dicts = {}
op_dicts["dispatch_body"] = dispatch_body
op_dicts["op_name"] = camel_names[epi_func]
generated_code += SubstituteTemplate(CommonWrapperForPhi, op_dicts)
return generated_code
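# Hypothetical invocation (all names here are illustrative, not the real
# tables built by conv2d_bias_act.py): generate the phi-facing dispatch
# functions for two epilogues on SM75 and SM80.
if __name__ == "__main__":
    underscore = {"relu": "conv2d_bias_relu", "silu": "conv2d_bias_silu"}
    camel = {"relu": "Conv2dBiasRelu", "silu": "Conv2dBiasSilu"}
    print(GenerateFunctionForPhi(["75", "80"], ["relu", "silu"], underscore, camel))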
......@@ -46,16 +46,18 @@ typedef struct {
int dilation_w;
int oh;
int ow;
int groups;
const phi::GPUContext *ctx;
float alpha; // for leaky_relu use
int sm_version = 75;
} ConvAllParams;
// Below functions are provided by cutlass, they are called by phi.
void Conv2dBiasAddRelu(ConvAllParams params);
void Conv2dBiasRelu(ConvAllParams params);
void Conv2dBiasLeakyRelu(ConvAllParams params);
void Conv2dBiasSilu(ConvAllParams params);
void Conv2dBias(ConvAllParams params);
void Conv2dBiasAddRelu(const ConvAllParams &params);
void Conv2dBiasRelu(const ConvAllParams &params);
void Conv2dBiasLeakyRelu(const ConvAllParams &params);
void Conv2dBiasSilu(const ConvAllParams &params);
void Conv2dBias(const ConvAllParams &params);
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
......@@ -62,12 +62,14 @@ __global__ void naive_conv2d_kernel(const half *input,
int dilation_w,
int oh,
int ow,
int groups,
const half *residual,
float alpha, // for leaky_relu
OpType op_type) {
int M = batch * oh * ow;
int N = oc;
int K = ic * kh * kw;
int kc = ic / groups;
int K = kc * kh * kw;
int m_i = threadIdx.x + blockIdx.x * blockDim.x;
int n_i = threadIdx.y + blockIdx.y * blockDim.y;
if (m_i >= M || n_i >= N) return;
......@@ -76,19 +78,20 @@ __global__ void naive_conv2d_kernel(const half *input,
int oh_i = (m_i % (oh * ow)) / ow;
int ow_i = (m_i % (oh * ow)) % ow;
int oc_i = n_i;
int groups_i = (oc_i / (oc / groups));
struct logical_coord weight_shape = {oc, ic, kh, kw};
struct logical_coord weight_shape = {oc, kc, kh, kw};
struct logical_coord input_shape = {batch, ic, ih, iw};
int out_offset = m_i * N + n_i;
float *out_ptr = output + out_offset;
float sum = 0.f;
for (int k_i = 0; k_i < K; k_i++) {
int ic_i = k_i / (kh * kw);
int ic_i = k_i / (kh * kw) + groups_i * kc;
int kh_i = (k_i % (kh * kw)) / kw;
int kw_i = (k_i % (kh * kw)) % kw;
struct logical_coord weight_index = {oc_i, ic_i, kh_i, kw_i};
struct logical_coord weight_index = {oc_i, k_i / (kh * kw), kh_i, kw_i};
int ih_i = oh_i * stride_h - pad_h + kh_i * dilation_h;
int iw_i = ow_i * stride_w - pad_w + kw_i * dilation_w;
......@@ -127,7 +130,7 @@ __global__ void naive_conv2d_kernel(const half *input,
}
}
float conv2d_diff_gpu(ConvAllParams params, OpType op_type) {
float conv2d_diff_gpu(const ConvAllParams &params, OpType op_type) {
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
......@@ -146,6 +149,7 @@ float conv2d_diff_gpu(ConvAllParams params, OpType op_type) {
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
const half *residual = params.residual;
int groups = params.groups;
int oh = params.oh;
int ow = params.ow;
......@@ -186,6 +190,7 @@ float conv2d_diff_gpu(ConvAllParams params, OpType op_type) {
dilation_w,
oh,
ow,
groups,
residual,
params.alpha,
op_type);
......@@ -212,7 +217,7 @@ std::string OpType2String(OpType op_type) {
return "conv2d_bias_relu";
break;
case CONV2D_BIAS_SILU:
return "conv2d_bias_add_silu";
return "conv2d_bias_silu";
break;
case CONV2D_BIAS_ADD_RELU:
return "conv2d_bias_add_relu";
......@@ -227,7 +232,7 @@ std::string OpType2String(OpType op_type) {
int ProfileToGetBestConfig(
const std::vector<std::function<cutlass::Status(ConvAllParams)>> &all_func,
ConvAllParams params,
const ConvAllParams &params,
OpType op_type) {
constexpr int WARMUP = 10;
constexpr int REPEAT = 100;
......@@ -258,10 +263,11 @@ int ProfileToGetBestConfig(
if (elapsed_time < min_time && status == cutlass::Status::kSuccess) {
min_time = elapsed_time;
min_time_index = i;
// debug code
VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff "
<< conv2d_diff_gpu(params, op_type) << " compared with baseline,"
<< "cost_time: " << elapsed_time << "ms.";
}
// debug code
VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff "
<< conv2d_diff_gpu(params, op_type) << " compared with baseline.";
}
if (min_time_index < 0) {
......
......@@ -23,6 +23,7 @@
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
......@@ -40,16 +41,17 @@ typedef enum {
CONV2D_BIAS_RELU,
CONV2D_BIAS_ADD_RELU,
CONV2D_BIAS_SILU,
CONV2D_BIAS_LEAKY_RELU
CONV2D_BIAS_LEAKY_RELU,
CONV2D_BIAS_SILU_ADD
} OpType;
// conv2d_diff_gpu calculate diff of cutlass output and baseline output, you can
// use them to debug. return value is the max diff between cutlass and baseline.
float conv2d_diff_gpu(ConvAllParams params, OpType op_type);
float conv2d_diff_gpu(const ConvAllParams& params, OpType op_type);
int ProfileToGetBestConfig(
const std::vector<std::function<cutlass::Status(ConvAllParams)>>& all_func,
ConvAllParams params,
const ConvAllParams& params,
OpType op_type);
} // namespace cutlass_internal
......
......@@ -43,12 +43,14 @@ void Conv2dFusionKernel(const Context& ctx,
CHECK_EQ(filter_dims.size() == 4UL, true);
CHECK_EQ(strides.size() == 2UL, true);
CHECK_EQ(dilations.size() == 2UL, true);
CHECK_EQ(groups == 1, true);
CHECK_EQ(padding_algorithm == "EXPLICIT", true);
const int batch = in_dims[0];
const int ic = in_dims[3];
const int ih = in_dims[1];
const int iw = in_dims[2];
CHECK_EQ(groups == 1, true);
CHECK_EQ(ic == groups * filter_dims[3], true);
int pad_h0 = 0;
int pad_h1 = 0;
int pad_w0 = 0;
......@@ -104,6 +106,7 @@ void Conv2dFusionKernel(const Context& ctx,
dilation_w,
oh,
ow,
groups,
&ctx};
if (residual) {
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re
class TileDesc:
def __init__(self, Tshape, stages, Wshape, math_inst):
self.Tshape = Tshape
self.stages = stages
self.Wshape = Wshape
self.math_inst = math_inst
def SubstituteTemplate(template, values):
text = template
changed = True
while changed:
changed = False
for key, value in values.items():
regex = "\\$\\{%s\\}" % key
newtext = re.sub(regex, value, text)
if newtext != text:
changed = True
text = newtext
return text
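# Usage sketch (hypothetical template, not one of the real kernel templates).
# The while-changed loop keeps substituting until the text is stable, so
# placeholders introduced by a substituted value are resolved too.
def _substitute_template_demo():
    tmpl = "void ${f}() { ${body} }"
    out = SubstituteTemplate(tmpl, {"f": "demo", "body": "return;"})
    assert out == "void demo() { return; }"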
......@@ -27,7 +27,7 @@ import paddle.inference as paddle_infer
class TestCutlassConv2dFusionOp1(CutlassAutoScanTest):
def sample_program_configs(self, *args, **kwargs):
def generate_input1(input_shape):
return np.random.random(input_shape).astype(np.float32)
return (np.random.random(input_shape) - 0.5).astype(np.float32)
def generate_weight(weight_shape):
return np.random.random(weight_shape).astype(np.float32)
......@@ -105,7 +105,9 @@ class TestCutlassConv2dFusionOp1(CutlassAutoScanTest):
"op_type": act,
"op_inputs": {"X": ["output_data0"]},
"op_outputs": {"Out": ["output_data1"]},
"op_attrs": {},
"op_attrs": {
"alpha": 2.0,
},
},
]
......