Unverified commit 0f6c5459, authored by zhoutianzi666, committed by GitHub

[Paddle Inference]add cutlass act set in conv_elementwise_add_act_fuse_pass (#48838)

* add cutlass act set in conv_elementwise_add_act_fuse_pass
Parent commit: 4c563e0b
...@@ -138,6 +138,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -138,6 +138,19 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
FusePassBase::Init("data_layout_transfer", graph); FusePassBase::Init("data_layout_transfer", graph);
auto *scope = param_scope(); auto *scope = param_scope();
// only float16 compute precision need insert transfer_layout.
bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_half");
bool cutlass_enable = false;
#ifdef PADDLE_WITH_CUTLASS
cutlass_enable = true;
#endif
if (!(is_fp16_precision && cutlass_enable)) return;
PADDLE_ENFORCE_EQ(graph->IsMainGraph(), PADDLE_ENFORCE_EQ(graph->IsMainGraph(),
true, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -169,14 +182,24 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -169,14 +182,24 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
if (data_format != "NCHW") return false; if (data_format != "NCHW") return false;
auto filter_names = op_node->Op()->Input("Filter"); auto filter_names = op_node->Op()->Input("Filter");
auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
std::unordered_set<std::string> cutlass_act_set = {
"relu", "swish", "identity", "leaky_relu"};
if (!cutlass_act_set.count(act_type)) {
return false;
}
// If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc. // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
for (const auto &filter_name : filter_names) { for (const auto &filter_name : filter_names) {
auto *filter_var = scope->FindLocalVar(filter_name); auto *filter_var = scope->FindLocalVar(filter_name);
const auto &filter_tensor = filter_var->Get<phi::DenseTensor>(); const auto &filter_tensor = filter_var->Get<phi::DenseTensor>();
if (filter_tensor.dims().size() == 4 && CHECK_EQ(filter_tensor.dims().size() == 4UL, true);
(filter_tensor.dims()[0] % 8 != 0 || int oc = filter_tensor.dims()[0];
filter_tensor.dims()[1] % 8 != 0)) { int ic = filter_tensor.dims()[1];
bool cutlass_can_support =
oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
if (!cutlass_can_support) {
return false; return false;
} }
} }
...@@ -190,6 +213,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const { ...@@ -190,6 +213,7 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
auto *op_desc = op_node->Op(); auto *op_desc = op_node->Op();
auto nhwc_attr = framework::Attribute(std::string("NHWC")); auto nhwc_attr = framework::Attribute(std::string("NHWC"));
op_desc->SetAttr("data_format", nhwc_attr); op_desc->SetAttr("data_format", nhwc_attr);
op_desc->SetType("conv2d_fusion_cutlass");
op_desc->Flush(); op_desc->Flush();
// transfer weights // transfer weights
......
...@@ -36,7 +36,8 @@ framework::proto::OpDesc PrepareOpDesc( ...@@ -36,7 +36,8 @@ framework::proto::OpDesc PrepareOpDesc(
const framework::proto::OpDesc& base_desc, const framework::proto::OpDesc& base_desc,
const std::string& bias, const std::string& bias,
const std::string& activation, const std::string& activation,
const std::string& output) { const std::string& output,
float alpha) {
auto proto = base_desc; auto proto = base_desc;
framework::OpDesc desc(proto, nullptr); framework::OpDesc desc(proto, nullptr);
desc.SetType("conv2d_fusion"); desc.SetType("conv2d_fusion");
...@@ -46,6 +47,8 @@ framework::proto::OpDesc PrepareOpDesc( ...@@ -46,6 +47,8 @@ framework::proto::OpDesc PrepareOpDesc(
desc.SetOutput("Output", {output}); desc.SetOutput("Output", {output});
desc.SetAttr("is_test", true); desc.SetAttr("is_test", true);
desc.SetAttr("use_cudnn", false); desc.SetAttr("use_cudnn", false);
// for leaky_relu use
desc.SetAttr("fuse_alpha", alpha);
desc.Flush(); desc.Flush();
return *desc.Proto(); return *desc.Proto();
} }
...@@ -118,6 +121,25 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() { ...@@ -118,6 +121,25 @@ ConvElementwiseAddActFusePass::ConvElementwiseAddActFusePass() {
.AddOutput("Out") .AddOutput("Out")
.IsTensor() .IsTensor()
.End(); .End();
AddOpCompat(OpCompat("swish"))
.AddInput("X")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End();
AddOpCompat(OpCompat("leaky_relu"))
.AddInput("X")
.IsTensor()
.End()
.AddAttr("alpha")
.IsType<float>()
.End()
.AddOutput("Out")
.IsTensor()
.End();
} }
void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
...@@ -137,8 +159,28 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -137,8 +159,28 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::unordered_set<std::string> cudnn_act_set({"identity", "relu"}); std::unordered_set<std::string> cudnn_act_set({"identity", "relu"});
#endif #endif
std::unordered_set<std::string> cutlass_act_set;
std::unordered_set<std::string> all_act_set = cudnn_act_set;
bool is_fp16_precision =
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_half");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
if (is_fp16_precision) {
#ifdef PADDLE_WITH_CUTLASS
// cutlass now support these activations
// cutlass_act_set.insert("swish");
// cutlass_act_set.insert("relu");
// cutlass_act_set.insert("identity");
// cutlass_act_set.insert("leaky_relu");
all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
#endif
}
patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name); patterns::ConvElementwiseaddAct pattern(gpd.mutable_pattern(), pattern_name);
pattern(x, cudnn_act_set); pattern(x, all_act_set);
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
...@@ -152,9 +194,27 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -152,9 +194,27 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
std::string bias_name = elementwise_add_in_y->Name(); std::string bias_name = elementwise_add_in_y->Name();
std::string act_op_type = act_op->Op()->Type(); std::string act_op_type = act_op->Op()->Type();
std::string act_op_out = act_out->Name(); std::string act_op_out = act_out->Name();
auto* scope = param_scope();
auto* filter_var = scope->FindLocalVar(conv_filter->Name());
auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
// when this conv2d_fusion problem size is not supported by cutlass and not
// supported by cuDNN, we should not apply this pass
int oc = filter_tensor->dims()[0];
int ic = filter_tensor->dims()[1];
bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
ic % CUTLASS_NHWC_ALIGNMENT == 0 &&
cutlass_act_set.count(act_op_type);
bool cudnn_can_fuse = cudnn_act_set.count(act_op_type);
if (!cutlass_can_fuse && !cudnn_can_fuse) {
return;
}
float alpha = 0.f;
alpha = act_op->Op()->GetAttrIfExists<float>("alpha");
auto new_op_proto = auto new_op_proto =
PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out); PrepareOpDesc(base_op_desc, bias_name, act_op_type, act_op_out, alpha);
framework::OpDesc new_op_desc(new_op_proto, nullptr); framework::OpDesc new_op_desc(new_op_proto, nullptr);
// Create a new node for the fused op. // Create a new node for the fused op.
...@@ -195,4 +255,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass) ...@@ -195,4 +255,6 @@ REGISTER_PASS_CAPABILITY(conv_elementwise_add_act_fuse_pass)
.EQ("relu", 0) .EQ("relu", 0)
.EQ("sigmoid", 0) .EQ("sigmoid", 0)
.EQ("tanh", 0) .EQ("tanh", 0)
.EQ("identity", 0)); .EQ("identity", 0)
.LE("leaky_relu", 1)
.EQ("swish", 0));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
To comment, please register.