Unverified commit bd3b096a authored by zhoutianzi666, committed by GitHub

[Paddle-Inference] Add cutlass conv2d_depthwise (#51792)

* initial commit for cutlass_teller

* second commit for cutlass_teller

* add conv2d_depthwise python template

* add conv2d_depthwise cutlass template

* /zhoukangkang/paddle_cutlass/Paddle/paddle/fluid/framework/ir/cutlass_teller.h

* refine code in Conv2dFusionCanSupport

* add macro in cutlass_teller.h

* add 3x3 5x5 teller

* add groups not 1 or conv2d_depthwise teller

* only generate conv2d_depthwise kernels whose ic is a multiple of 8

* add EXPLICIT in cutlass_teller.h

* final commit

* add split_k_slices in conv2d_depthwise

* make stages == 2

* refactor some of the code

* add CutlassFusionType

* solve illegal memory

* make stride_h=stride_w && make dilation==1

* must check HasAttr(use_cutlass) before GetAttrIfExists

* add CONV2D_DEPTHWISE_BIAS_SILU to OpType2String

* modify decl.h and util.cu
Parent bafe287a
......@@ -13,10 +13,10 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/cutlass_teller.h"
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
......@@ -112,17 +112,6 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
bool cutlass_enable = Get<bool>("use_cutlass");
#ifdef PADDLE_WITH_CUTLASS
const auto &prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
int sm_version = prop.major * 10 + prop.minor;
// Now we only implement cutlass kernel on SM75.
if (sm_version == 75) {
} else {
cutlass_enable = false;
}
#endif
if (!is_fp16_precision) return;
PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
......@@ -152,26 +141,22 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
std::string target_op_type = "conv2d_fusion";
std::unordered_set<ir::Node *> valid_ops;
// Determine whether this conv2d_fusion can run in cuDNN's NHWC mode;
// this will not set or change any attribute in op_desc.
auto cuDNNIsValid = [&](ir::Node *op_node) -> bool {
if (op_node->Op()->Type() != target_op_type) return false;
auto data_format =
op_node->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format != "NCHW") return false;
auto filter_names = op_node->Op()->Input("Filter");
constexpr int NHWC_ALIGNMENT = 8;
// If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
constexpr int CUDNN_ALIGNMENT = 8;
// If the filter's channel count is not a multiple of CUDNN_ALIGNMENT,
// conv2d_fusion will not run in NHWC.
for (const auto &filter_name : filter_names) {
if (weights_shape_nhwc.count(filter_name)) {
continue;
}
auto *filter_var = scope->FindLocalVar(filter_name);
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
int oc = filter_tensor.dims()[0];
int ic = filter_tensor.dims()[1];
bool cutlass_can_support =
oc % NHWC_ALIGNMENT == 0 && ic % NHWC_ALIGNMENT == 0;
if (!cutlass_can_support) {
bool cudnn_can_support =
oc % CUDNN_ALIGNMENT == 0 && ic % CUDNN_ALIGNMENT == 0;
if (!cudnn_can_support) {
return false;
}
}
......@@ -179,39 +164,44 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
};
auto CutlassIsValid = [&](ir::Node *op_node) -> bool {
auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
// conv2d_fusion has two forms: conv + bias + act, conv + bias +
// elmentwise_add + act.
std::unordered_set<std::string> cutlass_cba_act_set = {
"relu", "swish", "identity", "leaky_relu"};
std::unordered_set<std::string> cutlass_cbaa_act_set = {"relu"};
bool is_residual = op_node->Op()->Input("ResidualData").size() >= 1UL;
if (is_residual) {
if (!cutlass_cbaa_act_set.count(act_type)) {
return false;
}
} else {
if (!cutlass_cba_act_set.count(act_type)) {
return false;
}
auto op_desc = op_node->Op();
bool use_cutlass = false;
if (op_desc->HasAttr("use_cutlass")) {
use_cutlass = op_desc->GetAttrIfExists<bool>("use_cutlass");
}
return true;
return use_cutlass && cutlass_enable;
};
for (auto *op_node : op_nodes) {
CHECK_EQ(op_node->IsOp(), true);
if (cuDNNIsValid(op_node)) {
// Some common checks.
if (op_node->Op()->Type() != target_op_type) {
continue;
}
auto filter_name = op_node->Op()->Input("Filter").front();
if (weights_shape_nhwc.count(filter_name)) {
continue;
}
auto data_format =
op_node->Op()->GetAttrIfExists<std::string>("data_format");
if (data_format != "NCHW") {
continue;
}
if (cuDNNIsValid(op_node) || CutlassIsValid(op_node)) {
valid_ops.insert(op_node);
auto *op_desc = op_node->Op();
op_desc->SetAttr("data_format", std::string{"NHWC"});
if (cutlass_enable && CutlassIsValid(op_node)) {
if (CutlassIsValid(op_node)) {
op_desc->SetType("conv2d_fusion_cutlass");
// conv2d_fusion_cutlass must have this attribute because of its signature.
if (!op_desc->HasAttr("fuse_alpha")) {
op_desc->SetAttr("fuse_alpha", 0.f);
}
}
op_desc->SetAttr("data_format", std::string{"NHWC"});
op_desc->Flush();
// transfer weights
......
......@@ -14,7 +14,7 @@
#include "paddle/fluid/framework/ir/conv_elementwise_add2_act_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/ir/cutlass_teller.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -138,8 +138,21 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif
std::unordered_set<std::string> cutlass_act_set =
CutlassTeller::Instance()->CbaaAct(Get<int>("gpu_device_id"));
std::unordered_set<std::string> all_act_set = cudnn_act_set;
bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
bool cutlass_enable = Get<bool>("use_cutlass");
if (is_fp16_precision && cutlass_enable) {
all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
}
patterns::ConvElementwiseadd2Act pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, cudnn_act_set);
pattern(x, all_act_set);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
......@@ -166,9 +179,21 @@ void ConvElementwiseAdd2ActFusePass::ApplyImpl(ir::Graph* graph) const {
return;
}
auto* scope = param_scope();
bool cutlass_can_fuse = CutlassTeller::Instance()->CbaaCanSupport(
conv_op->Op(), scope, act_op_type, Get<int>("gpu_device_id"));
bool cudnn_can_fuse = cudnn_act_set.count(act_op_type);
if (!cutlass_can_fuse && !cudnn_can_fuse) {
return;
}
auto new_op_proto = PrepareOpDesc(
base_op_desc, bias_name, bias1_name, act_op_type, act_op_out);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
if (cutlass_can_fuse && cutlass_enable && is_fp16_precision) {
new_op_desc.SetAttr("use_cutlass", true);
}
// Create a new node for the fused op.
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.h"
#include "paddle/fluid/framework/ir/cutlass_teller.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -159,29 +159,17 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif
std::unordered_set<std::string> cutlass_act_set;
std::unordered_set<std::string> cutlass_act_set =
CutlassTeller::Instance()->CbaAct(Get<int>("gpu_device_id"));
std::unordered_set<std::string> all_act_set = cudnn_act_set;
bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
bool cutlass_enable = Get<bool>("use_cutlass");
if (is_fp16_precision && cutlass_enable) {
#ifdef PADDLE_WITH_CUTLASS
const auto& prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
int sm_version = prop.major * 10 + prop.minor;
// Now we only implement cutlass kernel on SM75.
if (sm_version == 75) {
// Cutlass now support these cba activations.
cutlass_act_set.insert("swish");
cutlass_act_set.insert("relu");
cutlass_act_set.insert("identity");
cutlass_act_set.insert("leaky_relu");
}
all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
#endif
}
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
......@@ -200,17 +188,12 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name();
auto* scope = param_scope();
auto* filter_var = scope->FindLocalVar(conv_filter->Name());
auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
// When this conv2d_fusion problem size is not supported by cutlass and not
// supported by cuDNN, we should not apply this pass.
int oc = filter_tensor->dims()[0];
int ic = filter_tensor->dims()[1];
bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
ic % CUTLASS_NHWC_ALIGNMENT == 0 &&
cutlass_act_set.count(act_op_type);
bool cutlass_can_fuse = CutlassTeller::Instance()->CbaCanSupport(
conv_op->Op(), scope, act_op_type, Get<int>("gpu_device_id"));
bool cudnn_can_fuse = cudnn_act_set.count(act_op_type);
// When this conv2d_fusion, as specified by problem size and act type, is
// supported by neither cutlass nor cuDNN, we should not apply this pass.
if (!cutlass_can_fuse && !cudnn_can_fuse) {
return;
}
......@@ -221,7 +204,9 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
auto new_op_proto =
PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out, alpha);
framework::OpDesc new_op_desc(new_op_proto, nullptr);
if (cutlass_can_fuse && cutlass_enable && is_fp16_precision) {
new_op_desc.SetAttr("use_cutlass", true);
}
// Create a new node for the fused op.
auto* new_conv_op = graph->CreateOpNode(&new_op_desc);
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/framework/ir/conv_elementwise_add_fuse_pass.h"
#include "paddle/fluid/framework/ir/cutlass_teller.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -116,6 +116,19 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
new_op_desc.SetOutput("Output", {output_name});
new_op_desc.SetAttr("is_test", true);
new_op_desc.SetAttr("use_cudnn", false);
bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
bool cutlass_enable = Get<bool>("use_cutlass");
auto* scope = param_scope();
bool cutlass_can_fuse = CutlassTeller::Instance()->CbaCanSupport(
conv_op->Op(), scope, act_type, Get<int>("gpu_device_id"));
if (cutlass_can_fuse && cutlass_enable && is_fp16_precision) {
new_op_desc.SetAttr("use_cutlass", true);
}
auto* elementwise_add_op_desc = elementwise_add_op->Op();
auto out_threshold_attr =
elementwise_add_op_desc->GetNullableAttr("out_threshold");
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_set>
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
namespace paddle {
namespace framework {
namespace ir {
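// Fusion patterns the teller can judge:
//   cba : conv2d + bias + activation
//   cbaa: conv2d + bias + elementwise_add + activation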
typedef enum {
cba,
cbaa,
} CutlassFusionType;
class CutlassTeller {
public:
static CutlassTeller *Instance() {
static CutlassTeller global;
return &global;
}
#if defined(PADDLE_WITH_CUTLASS)
// Determine whether this NCHW conv2d + bias can be fused with an activation
// by cutlass; this will not set or change any attribute in op_desc.
bool CbaCanSupport(OpDesc *op_desc,
Scope *scope,
std::string act_type,
int device_id) {
auto strides = op_desc->GetAttrIfExists<std::vector<int>>("strides");
auto dilations = op_desc->GetAttrIfExists<std::vector<int>>("dilations");
CHECK_EQ(strides.size() == 2UL, true);
CHECK_EQ(dilations.size() == 2UL, true);
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
auto filter_names = op_desc->Input("Filter");
for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name);
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
auto groups = op_desc->GetAttrIfExists<int>("groups");
int oc = filter_tensor.dims()[0];
int kc = filter_tensor.dims()[1];
int kh = filter_tensor.dims()[2];
int kw = filter_tensor.dims()[3];
// For convenience, we only support EXPLICIT padding.
auto padding_algorithm =
op_desc->GetAttrIfExists<std::string>("padding_algorithm");
if (padding_algorithm != "EXPLICIT") {
return false;
}
if (!Conv2dCanSupport(oc,
kc,
kh,
kw,
stride_h,
stride_w,
dilation_h,
dilation_w,
groups,
act_type,
device_id,
CutlassFusionType::cba)) {
return false;
}
}
return true;
}
// Determine whether this NCHW conv2d + bias + elementwise_add + act can be
// fused by cutlass; this will not set or change any attribute in op_desc.
bool CbaaCanSupport(OpDesc *op_desc,
Scope *scope,
std::string act_type,
int device_id) {
auto strides = op_desc->GetAttrIfExists<std::vector<int>>("strides");
auto dilations = op_desc->GetAttrIfExists<std::vector<int>>("dilations");
CHECK_EQ(strides.size() == 2UL, true);
CHECK_EQ(dilations.size() == 2UL, true);
int stride_h = strides[0];
int stride_w = strides[1];
int dilation_h = dilations[0];
int dilation_w = dilations[1];
auto filter_names = op_desc->Input("Filter");
for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name);
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
auto groups = op_desc->GetAttrIfExists<int>("groups");
int oc = filter_tensor.dims()[0];
int kc = filter_tensor.dims()[1];
int kh = filter_tensor.dims()[2];
int kw = filter_tensor.dims()[3];
// For convenience, we only support EXPLICIT padding.
auto padding_algorithm =
op_desc->GetAttrIfExists<std::string>("padding_algorithm");
if (padding_algorithm != "EXPLICIT") {
return false;
}
if (!Conv2dCanSupport(oc,
kc,
kh,
kw,
stride_h,
stride_w,
dilation_h,
dilation_w,
groups,
act_type,
device_id,
CutlassFusionType::cbaa)) {
return false;
}
}
return true;
}
// Determine whether this conv can be fused with the activation by the
// cutlass backend.
bool Conv2dCanSupport(int oc,
int kc,
int kh,
int kw,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int groups,
std::string activation,
int device_id,
CutlassFusionType fuse_type) {
int sm_version = platform::GetGPUComputeCapability(device_id);
int ic = kc * groups;
if (!cutlass_sm.count(sm_version)) {
return false;
}
// To avoid generating too much cutlass code, we only allow oc and ic that
// are divisible by CUTLASS_NHWC_ALIGNMENT.
if (groups == 1) {
if (oc % CUTLASS_NHWC_ALIGNMENT != 0 ||
ic % CUTLASS_NHWC_ALIGNMENT != 0) {
return false;
}
// conv + bias + act
if (fuse_type == CutlassFusionType::cba &&
!cba_act_set.count(activation)) {
return false;
}
// conv + bias + elementwise_add + act
if (fuse_type == CutlassFusionType::cbaa &&
!cbaa_act_set.count(activation)) {
return false;
}
} else if (groups == ic && ic == oc) {
// conv2d_depthwise does not support a residual input
if (fuse_type != CutlassFusionType::cba) {
return false;
}
// Now we only support 3x3 and 5x5 filters with stride 1 or 2 (3x3s1s2, 5x5s1s2).
if (!((kh == 3 && kw == 3) || (kh == 5 && kw == 5))) {
return false;
}
if (!(stride_h == 1 || stride_h == 2)) {
return false;
}
if (stride_h != stride_w) {
return false;
}
if (dilation_h != 1) {
return false;
}
if (dilation_w != 1) {
return false;
}
// Now we only allow ic % 8 == 0, because of cutlass.
if (ic % 8 != 0) {
return false;
}
// conv2d_depthwise + bias + act
if (!cdba_act_set.count(activation)) {
return false;
}
} else {
// only support groups == 1 or conv2d_depthwise
return false;
}
return true;
}
// Return the activation set supported by the cutlass conv + bias + act pattern
std::unordered_set<std::string> CbaAct(int device_id) {
int sm_version = platform::GetGPUComputeCapability(device_id);
if (cutlass_sm.count(sm_version)) {
return cba_act_set;
} else {
return {};
}
}
// Return the activation set supported by the cutlass conv + bias +
// elementwise_add + act pattern
std::unordered_set<std::string> CbaaAct(int device_id) {
int sm_version = platform::GetGPUComputeCapability(device_id);
if (cutlass_sm.count(sm_version)) {
return cbaa_act_set;
} else {
return {};
}
}
#else
bool CbaaCanSupport(OpDesc *op_desc,
Scope *scope,
std::string act_type,
int device_id) {
return false;
}
bool CbaCanSupport(OpDesc *op_desc,
Scope *scope,
std::string act_type,
int device_id) {
return false;
}
bool Conv2dCanSupport(int oc,
int kc,
int kh,
int kw,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int groups,
std::string activation,
int device_id,
CutlassFusionType fuse_type) {
return false;
}
std::unordered_set<std::string> CbaAct(int device_id) { return {}; }
std::unordered_set<std::string> CbaaAct(int device_id) { return {}; }
#endif
static const int CUTLASS_NHWC_ALIGNMENT = 8;
const std::unordered_set<int> cutlass_sm = {
75,
};
const std::unordered_set<std::string> cba_act_set = {
"relu", "swish", "identity", "leaky_relu", "sigmoid"};
// conv2d_depthwise act
const std::unordered_set<std::string> cdba_act_set = {
"identity", "relu", "swish", "sigmoid"};
const std::unordered_set<std::string> cbaa_act_set = {"relu"};
};
} // namespace ir
} // namespace framework
} // namespace paddle
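For reference, a minimal sketch (not part of this commit) of how a fuse pass consults the teller; MaybeUseCutlass is a hypothetical helper, while the CbaCanSupport/CbaaCanSupport calls mirror the ones the passes above make:

// Hypothetical helper, assuming the includes and namespaces of the fuse
// passes above; returns true if either cutlass fusion pattern applies.
bool MaybeUseCutlass(OpDesc *op_desc,
                     Scope *scope,
                     const std::string &act_type,
                     int device_id) {
  auto *teller = CutlassTeller::Instance();
  // conv2d + bias + act
  if (teller->CbaCanSupport(op_desc, scope, act_type, device_id)) return true;
  // conv2d + bias + elementwise_add + act
  return teller->CbaaCanSupport(op_desc, scope, act_type, device_id);
}

Note that in builds without PADDLE_WITH_CUTLASS both queries return false, so such a helper degrades gracefully.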
......@@ -122,6 +122,7 @@ if(WITH_CUTLASS)
"${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d/generated"
COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_act.py"
COMMAND ${PYTHON_EXECUTABLE} "conv2d_bias_residual.py"
COMMAND ${PYTHON_EXECUTABLE} "conv2d_depthwise_bias_act.py"
WORKING_DIRECTORY "${CMAKE_CURRENT_SOURCE_DIR}/fusion/cutlass/conv2d")
execute_process(
......
......@@ -36,8 +36,10 @@ cba_header = '''
#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
......@@ -46,7 +48,7 @@ namespace cutlass_internal {
# This is a cutlass kernel template; many kernels like this will be generated.
dict_for_declare_part = {
"conv_kind_name": "Fprop",
"conv_kind_name": "DefaultConv2dFprop",
"epi_part": "${epi_func}< ${element_c}, ${epilogue_vector_length}, ${element_accum}, ${element_epilogue}>",
}
......@@ -78,6 +80,7 @@ class CbaAct(enum.Enum):
Relu = 2
Silu = 3
LeakyRelu = 4
Sigmoid = 5
# Some global variables; for now we only support these activations.
......@@ -86,6 +89,7 @@ SupportedAct = [
CbaAct.Relu,
CbaAct.Silu,
CbaAct.LeakyRelu,
CbaAct.Sigmoid,
]
ActTag = {
......@@ -93,6 +97,7 @@ ActTag = {
SupportedAct[1]: 'cutlass::epilogue::thread::LinearCombinationRelu',
SupportedAct[2]: 'cutlass::epilogue::thread::LinearCombinationSilu',
SupportedAct[3]: 'cutlass::epilogue::thread::LinearCombinationLeakyRelu',
SupportedAct[4]: 'cutlass::epilogue::thread::LinearCombinationSigmoid',
}
UnderScoreName = {
......@@ -100,6 +105,7 @@ UnderScoreName = {
SupportedAct[1]: "conv2d_bias_relu",
SupportedAct[2]: "conv2d_bias_silu",
SupportedAct[3]: "conv2d_bias_leaky_relu",
SupportedAct[4]: "conv2d_bias_sigmoid",
}
CamelName = {
......@@ -107,6 +113,7 @@ CamelName = {
SupportedAct[1]: "Conv2dBiasRelu",
SupportedAct[2]: "Conv2dBiasSilu",
SupportedAct[3]: "Conv2dBiasLeakyRelu",
SupportedAct[4]: "Conv2dBiasSigmoid",
}
# Generate sm75 TensorOp conv code.
......@@ -141,12 +148,12 @@ def generate_sm75_1688():
]
math_instructions = [
(
"16,8,8",
"cutlass::half_t",
"cutlass::half_t",
"cutlass::half_t",
),
# (
# "16,8,8",
# "cutlass::half_t",
# "cutlass::half_t",
# "cutlass::half_t",
# ),
(
"16,8,8",
"cutlass::half_t",
......@@ -161,6 +168,7 @@ def generate_sm75_1688():
kernel_dict["align_b"] = "8"
# this must evenly divide oc
kernel_dict["epilogue_vector_length"] = "8"
kernel_dict["split_k_slices"] = "1"
sm75_code = ""
for epi_func in SupportedAct:
......
......@@ -44,7 +44,7 @@ namespace cutlass_internal {
# This is a cutlass kernel template; many kernels like this will be generated.
dict_for_declare_part = {
"conv_kind_name": "FpropWithBroadcast",
"conv_kind_name": "DefaultConv2dFpropWithBroadcast",
"epi_part": "cutlass::epilogue::thread::LinearCombinationResidualBlock< ${element_c}, ${element_accum}, ${element_epilogue}, ${element_residul}, ${epilogue_vector_length}, ${act1}, ${binary}, ${act2}>",
}
......@@ -129,12 +129,12 @@ def generate_sm75_1688():
]
math_instructions = [
(
"16,8,8",
"cutlass::half_t",
"cutlass::half_t",
"cutlass::half_t",
),
# (
# "16,8,8",
# "cutlass::half_t",
# "cutlass::half_t",
# "cutlass::half_t",
# ),
(
"16,8,8",
"cutlass::half_t",
......@@ -148,6 +148,7 @@ def generate_sm75_1688():
kernel_dict["align_a"] = "8"
kernel_dict["align_b"] = "8"
kernel_dict["epilogue_vector_length"] = "8"
kernel_dict["split_k_slices"] = "1"
sm75_code = ""
for epi_res_block in SupportedEpilogue:
......
......@@ -24,7 +24,7 @@ from util import SubstituteTemplate
CommonCutlassConvKernelDeclare = """
cutlass::Status ${kernel_func_name}(const ConvAllParams& params) {
using kernel_base =
typename cutlass::conv::kernel::DefaultConv2d${conv_kind_name}<
typename cutlass::conv::kernel::${conv_kind_name}<
${element_a},
${layout_a},
${element_b},
......@@ -71,6 +71,7 @@ cutlass::Status ${kernel_func_name}(const ConvAllParams& params) {
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
int split_k_slices = ${split_k_slices};
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic / groups},
......@@ -79,7 +80,7 @@ cutlass::Status ${kernel_func_name}(const ConvAllParams& params) {
{dilation_h, dilation_w},
{batch, oh, ow, oc},
cutlass::conv::Mode::kCrossCorrelation,
1,
split_k_slices,
groups);
"""
......@@ -183,7 +184,7 @@ CommonTail = '''
'''
# wrap different sm versions into a function
# Wrap different sm versions into a function called by phi
def GenerateFunctionForPhi(
sm_versions, support_epi_funcs, underscore_names, camel_names
):
......@@ -202,3 +203,20 @@ def GenerateFunctionForPhi(
op_dicts["op_name"] = camel_names[epi_func]
generated_code += SubstituteTemplate(CommonWrapperForPhi, op_dicts)
return generated_code
# We modify some template parameters based on CommonCutlassConvKernelDeclare.
CommonCutlassConv2dDepthwiseKernelDeclare = (
CommonCutlassConvKernelDeclare.replace(
"${align_a}", "cutlass::MatrixShape<${strided_shape}>"
)
.replace("${align_b}", "cutlass::MatrixShape<${dilation_shape}>")
.replace("ImplicitGemmConvolution", "DirectConvolution")
.replace(
"cutlass::gemm::GemmShape<${Tshape}>,",
'''cutlass::gemm::GemmShape<${Tshape}>,
cutlass::conv::TensorNHWCShape<${T_output_shape}>,
cutlass::MatrixShape<${filter_shape}>,
''',
)
)
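# Compared with CommonCutlassConvKernelDeclare, the depthwise declare swaps
# ImplicitGemmConvolution for DirectConvolution and adds the per-CTA output
# tile shape and the filter shape alongside the threadblock GemmShape.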
......@@ -58,6 +58,13 @@ void Conv2dBiasRelu(const ConvAllParams &params);
void Conv2dBiasLeakyRelu(const ConvAllParams &params);
void Conv2dBiasSilu(const ConvAllParams &params);
void Conv2dBias(const ConvAllParams &params);
void Conv2dBiasSigmoid(const ConvAllParams &params);
void Conv2dDepthwiseBias(const ConvAllParams &params);
void Conv2dDepthwiseBiasRelu(const ConvAllParams &params);
void Conv2dDepthwiseBiasSigmoid(const ConvAllParams &params);
void Conv2dDepthwiseBiasSilu(const ConvAllParams &params);
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
sys.path.append("../")
import enum
from conv2d_common import (
CommonConvFunction,
CommonCutlassConv2dDepthwiseKernelDeclare,
CommonCutlassConvKernelExecute,
CommonTail,
)
from util import SubstituteTemplate
# This is the file's header part
cdba_header = '''
// Generated by conv2d_depthwise_bias_act.py - Do not edit.
#include <mutex>
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
#include <stdio.h>
#include <algorithm>
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/kernel/default_depthwise_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass/conv/device/direct_convolution.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
'''
# This is a cutlass kernel template; many kernels like this will be generated.
dict_for_declare_part = {
"conv_kind_name": "DefaultDepthwiseDirect2dConvFprop",
"epi_part": "${epi_func}< ${element_c}, ${epilogue_vector_length}, ${element_accum}, ${element_epilogue}>",
"swizzling_functor": '''cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle<${swizzling_shape}>''',
}
cba_kernel_no_alpha = (
SubstituteTemplate(
CommonCutlassConv2dDepthwiseKernelDeclare, dict_for_declare_part
)
+ '''
size_t filter_size = oc * kh * kw * kc * sizeof(half);
phi::Allocator::AllocationPtr filter_gpu_ptrs_data =
phi::memory_utils::Alloc(
params.ctx->GetPlace(),
filter_size,
phi::Stream(reinterpret_cast<phi::StreamId>(params.ctx->stream())));
void *filter_workspace = filter_gpu_ptrs_data->ptr();
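// The tensor refs below use NHWC strides: {C, C*W, C*W*H}.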
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)weight, {kc, kc * kw, kc * kw * kh}},
{(cutlass::half_t *)bias, {0, 0, 0}},
{(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}},
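// Epilogue scaling parameters {alpha, beta}, both fixed to 1 here.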
{1.f, 1.f},
{(cutlass::half_t *)filter_workspace, {kc, kc * kw, kc * kw * kh}},
};
'''
+ CommonCutlassConvKernelExecute
)
class CbaAct(enum.Enum):
Identity = 1
Relu = 2
Sigmoid = 3
Silu = 4
# Some global variables; for now we only support these activations.
SupportedAct = [CbaAct.Identity, CbaAct.Relu, CbaAct.Sigmoid, CbaAct.Silu]
ActTag = {
SupportedAct[0]: 'cutlass::epilogue::thread::LinearCombination',
SupportedAct[1]: 'cutlass::epilogue::thread::LinearCombinationRelu',
SupportedAct[2]: 'cutlass::epilogue::thread::LinearCombinationSigmoid',
SupportedAct[3]: 'cutlass::epilogue::thread::LinearCombinationSilu',
}
UnderScoreName = {
SupportedAct[0]: "conv2d_depthwise_bias",
SupportedAct[1]: "conv2d_depthwise_bias_relu",
SupportedAct[2]: "conv2d_depthwise_bias_sigmoid",
SupportedAct[3]: "conv2d_depthwise_bias_silu",
}
CamelName = {
SupportedAct[0]: "Conv2dDepthwiseBias",
SupportedAct[1]: "Conv2dDepthwiseBiasRelu",
SupportedAct[2]: "Conv2dDepthwiseBiasSigmoid",
SupportedAct[3]: "Conv2dDepthwiseBiasSilu",
}
def intlist2str(input):
return ",".join(str(x) for x in input)
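# e.g. intlist2str([8, 8, 16]) returns "8,8,16"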
# Generate SIMT conv2d_depthwise code.
def generate_conv2d_depthwise():
kernel_dict = {
"element_a": "cutlass::half_t",
"layout_a": "cutlass::layout::TensorNHWC",
"element_b": "cutlass::half_t",
"layout_b": "cutlass::layout::TensorNHWC",
"element_c": "cutlass::half_t",
"layout_c": "cutlass::layout::TensorNHWC",
"element_accum": "cutlass::half_t",
"opcode_class": "cutlass::arch::OpClassSimt",
"arch": "cutlass::arch::Sm70",
"Ishape": "1,1,1",
"stages": "2",
# alpha is always float!
"element_epilogue": "float",
"math_operator": "cutlass::arch::OpMultiplyAdd",
"iterator_algorithm": "cutlass::conv::IteratorAlgorithm::kFixedStrideDilation",
"stride_support": "cutlass::conv::StrideSupport::kStrided",
"dilation_shape": "1, 1",
}
# this must evenly divide oc
kernel_dict["epilogue_vector_length"] = "4"
all_code = ""
for epi_func in SupportedAct:
op_dict = {}
# Because conv2d_depthwise does not depend on the SM version, "func_name" is
# called by phi directly, so we use its camel-case name.
op_dict["func_name"] = CamelName[epi_func]
# enum_op_name is consistent with OpType in conv2d_util.h
op_dict["enum_op_name"] = UnderScoreName[epi_func].upper()
# For a function, we record all its kernels into a std::vector in C++ code
all_kernel_names = ""
kernel_dict["epi_func"] = ActTag[epi_func]
suffix = 0
filter_shapes = [[3, 3], [5, 5]]
stride_shapes = ["1,1", "2,2"]
# Setting [1, 2, 4, 8] would generate too many kernels!
# For now we only set [8].
for vec_length in ["8"]:
kernel_dict["epilogue_vector_length"] = vec_length
for filter_shape in filter_shapes:
for stride_shape in stride_shapes:
tiles = [
# [out_h, out_w, groups_per_cta, warp_m]
# out_h, out_w: output rows/cols each CTA processes
# groups_per_cta: groups (channels) each CTA processes
# warp_m: GEMM M dimension each warp processes
[8, 8, 16, 16],
# [8, 16, 16, 16],
# [16, 8, 16, 16],
[8, 8, 32, 16],
# [8, 16, 32, 16],
# [16, 8, 32, 16],
]
filter_size = filter_shape[0] * filter_shape[1]
for tile in tiles:
# each CTA processes a [1, out_h, out_w, groups_per_cta] output tile
kernel_dict["T_output_shape"] = intlist2str(
[1, tile[0], tile[1], tile[2]]
)
# the tile each CTA processes, viewed as a GEMM
kernel_dict["Tshape"] = intlist2str(
[tile[0] * tile[1], tile[2], filter_size]
)
kernel_dict["Wshape"] = intlist2str(
[tile[3], tile[2], filter_size]
)
kernel_dict["swizzling_shape"] = intlist2str(
[1, 1, tile[0], tile[1]]
)
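# ceil(oh * ow / 64) slices; 64 matches the 8x8 output tile each CTA covers.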
kernel_dict["split_k_slices"] = "(oh * ow + 63) / 64"
kernel_dict["filter_shape"] = intlist2str(filter_shape)
kernel_dict["strided_shape"] = stride_shape
kernel_dict["kernel_func_name"] = (
UnderScoreName[epi_func].lower() + "_" + str(suffix)
)
suffix += 1
all_code += SubstituteTemplate(
cba_kernel_no_alpha, kernel_dict
)
all_kernel_names += (
kernel_dict["kernel_func_name"] + ", \n"
)
# generate op code
op_dict["all_kernel_func_name"] = all_kernel_names
all_code += SubstituteTemplate(CommonConvFunction, op_dict)
return all_code
if __name__ == "__main__":
all_code = cdba_header
all_code += generate_conv2d_depthwise()
all_code += CommonTail
with open("generated/conv2d_depthwise_bias_act.cu", "w") as f:
f.write(all_code)
......@@ -112,12 +112,15 @@ __global__ void naive_conv2d_kernel(const half *input,
switch (op_type) {
case CONV2D_BIAS:
case CONV2D_DEPTHWISE_BIAS:
*out_ptr = x;
break;
case CONV2D_BIAS_RELU:
case CONV2D_DEPTHWISE_BIAS_RELU:
*out_ptr = x > 0 ? x : 0;
break;
case CONV2D_BIAS_SILU:
case CONV2D_DEPTHWISE_BIAS_SILU:
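// SiLU/Swish: x * sigmoid(x)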
*out_ptr = x * (1.f / (1 + exp(-x)));
break;
case CONV2D_BIAS_ADD_RELU:
......@@ -127,6 +130,10 @@ __global__ void naive_conv2d_kernel(const half *input,
case CONV2D_BIAS_LEAKY_RELU:
*out_ptr = x > 0 ? x : (x * alpha);
break;
case CONV2D_BIAS_SIGMOID:
case CONV2D_DEPTHWISE_BIAS_SIGMOID:
*out_ptr = 1.f / (1.f + std::exp(-x));
break;
default:
break;
}
......@@ -221,11 +228,22 @@ std::string OpType2String(OpType op_type) {
case CONV2D_BIAS_SILU:
return "conv2d_bias_silu";
break;
case CONV2D_BIAS_SIGMOID:
return "conv2d_bias_sigmoid";
break;
case CONV2D_BIAS_ADD_RELU:
return "conv2d_bias_add_relu";
break;
case CONV2D_BIAS_LEAKY_RELU:
return "conv2d_bias_leaky_relu";
case CONV2D_DEPTHWISE_BIAS:
return "conv2d_depthwise_bias";
case CONV2D_DEPTHWISE_BIAS_RELU:
return "conv2d_depthwise_bias_relu";
case CONV2D_DEPTHWISE_BIAS_SIGMOID:
return "conv2d_depthwise_bias_sigmoid";
case CONV2D_DEPTHWISE_BIAS_SILU:
return "conv2d_depthwise_bias_silu";
default:
break;
}
......@@ -245,6 +263,11 @@ int ProfileToGetBestConfig(
auto func = all_func[i];
// When a func has a large diff, we will make it nullptr.
if (!func) continue;
cudaMemset(params.output,
0,
sizeof(half) * params.batch * params.oc * params.oh * params.ow);
status = func(params);
if (status != cutlass::Status::kSuccess) continue;
for (int ii = 0; ii < WARMUP; ii++) {
status = func(params);
......
......@@ -44,7 +44,12 @@ typedef enum {
CONV2D_BIAS_ADD_RELU,
CONV2D_BIAS_SILU,
CONV2D_BIAS_LEAKY_RELU,
CONV2D_BIAS_SILU_ADD
CONV2D_BIAS_SIGMOID,
CONV2D_BIAS_SILU_ADD,
CONV2D_DEPTHWISE_BIAS,
CONV2D_DEPTHWISE_BIAS_RELU,
CONV2D_DEPTHWISE_BIAS_SIGMOID,
CONV2D_DEPTHWISE_BIAS_SILU,
} OpType;
// conv2d_diff_gpu calculates the diff of cutlass output and baseline output, you can
......
......@@ -49,7 +49,7 @@ void Conv2dFusionKernel(const Context& ctx,
const int ic = in_dims[3];
const int ih = in_dims[1];
const int iw = in_dims[2];
CHECK_EQ(groups == 1, true);
CHECK_EQ(ic == groups * filter_dims[3], true);
int pad_h0 = 0;
int pad_h1 = 0;
......@@ -109,6 +109,30 @@ void Conv2dFusionKernel(const Context& ctx,
groups,
&ctx};
// conv2d_depthwise
if (groups == ic && ic == oc) {
// cutlass conv2d_depthwise does not support residual
if (residual) {
CHECK_EQ(residual->data<T>() == nullptr, true);
}
if (activation == "relu") {
Conv2dDepthwiseBiasRelu(params);
} else if (activation == "identity") {
Conv2dDepthwiseBias(params);
} else if (activation == "sigmoid") {
Conv2dDepthwiseBiasSigmoid(params);
} else if (activation == "swish") {
Conv2dDepthwiseBiasSilu(params);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Cutlass conv2d_depthwise does not support this activation: %s.",
activation.c_str()));
}
return;
}
// below: conv2d_fusion && groups == 1
CHECK_EQ(groups == 1, true);
if (residual) {
if (activation == "relu") {
params.residual = reinterpret_cast<const half*>(residual->data<T>());
......@@ -126,6 +150,8 @@ void Conv2dFusionKernel(const Context& ctx,
} else if (activation == "leaky_relu") {
params.alpha = fuse_alpha;
Conv2dBiasLeakyRelu(params);
} else if (activation == "sigmoid") {
Conv2dBiasSigmoid(params);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Cutlass does not support this activation: %s.", activation.c_str()));
......