diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 088847d1f6f3b93ff9465eda4d70c6f2c421925b..06323119a7dc64a54682b201ec30d2c0cf03872b 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -105,6 +105,7 @@ pass_library(delete_fill_constant_op_pass inference)
 pass_library(constant_folding_pass inference)
 pass_library(auto_mixed_precision_pass inference)
 pass_library(conv2d_fusion_layout_transfer_pass inference)
+pass_library(silu_fuse_pass inference)
 pass_library(simplify_with_basic_ops_pass base)
 pass_library(fc_elementwise_layernorm_fuse_pass base)
 pass_library(skip_layernorm_fuse_pass base)
diff --git a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
index efed7dd6e637bc7e9421b3d4afb2090a1c47336c..dd4e0735600bec2a560a00faf338a944bacf702e 100644
--- a/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
+++ b/paddle/fluid/framework/ir/conv2d_fusion_layout_transfer_pass.cc
@@ -143,10 +143,15 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
       static_cast<phi::DataType>(Get<int>("model_precision")) ==
           phi::DataType::FLOAT16 ||
       Get<bool>("enable_gpu_mixed");
-  bool cutlass_enable = false;
+  bool cutlass_enable = Get<bool>("use_cutlass");
 #ifdef PADDLE_WITH_CUTLASS
-  cutlass_enable = true;
+  const auto &prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
+  int sm_version = prop.major * 10 + prop.minor;
+  // Now we only implement cutlass kernels on SM75.
+  if (sm_version != 75) {
+    cutlass_enable = false;
+  }
 #endif

   if (!(is_fp16_precision && cutlass_enable)) return;
@@ -184,10 +189,21 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
     auto filter_names = op_node->Op()->Input("Filter");
     auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
     constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
-    std::unordered_set<std::string> cutlass_act_set = {
+    // conv2d_fusion has two forms: conv + bias + act and
+    // conv + bias + elementwise_add + act.
+    std::unordered_set<std::string> cutlass_cba_act_set = {
         "relu", "swish", "identity", "leaky_relu"};
-    if (!cutlass_act_set.count(act_type)) {
-      return false;
+    std::unordered_set<std::string> cutlass_cbaa_act_set = {"relu"};
+    bool is_residual = op_node->Op()->Input("ResidualData").size() >= 1UL;
+
+    if (is_residual) {
+      if (!cutlass_cbaa_act_set.count(act_type)) {
+        return false;
+      }
+    } else {
+      if (!cutlass_cba_act_set.count(act_type)) {
+        return false;
+      }
     }
     // If filter's channel is not multiple of 8, conv2d_fusion not run at nhwc.
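A note on the gate above: both this pass and conv_elementwise_add_act_fuse_pass (next hunk) key the cutlass path off the device's SM version. A minimal standalone sketch of the same check, written against the plain CUDA runtime API instead of paddle::platform::GetDeviceProperties (the helper name is illustrative, not part of the patch):

#include <cuda_runtime.h>

// Returns true only on SM75 (e.g. Turing T4), mirroring the gate above; on
// any other architecture the passes leave the graph alone and conv2d_fusion
// keeps its cuDNN path.
bool CutlassConv2dSupported(int device_id) {
  cudaDeviceProp prop;
  if (cudaGetDeviceProperties(&prop, device_id) != cudaSuccess) return false;
  int sm_version = prop.major * 10 + prop.minor;
  return sm_version == 75;
}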
diff --git a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
index 2f527ff1e707bb986aef0da8d721ab8920d6d048..ba18b04d9d04576532a786e940efc02b6d349fd3 100644
--- a/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
+++ b/paddle/fluid/framework/ir/conv_elementwise_add_act_fuse_pass.cc
@@ -167,14 +167,19 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
           phi::DataType::FLOAT16 ||
       Get<bool>("enable_gpu_mixed");
   constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
-  if (is_fp16_precision) {
+  bool cutlass_enable = Get<bool>("use_cutlass");
+  if (is_fp16_precision && cutlass_enable) {
 #ifdef PADDLE_WITH_CUTLASS
-    // cutlass now support these activations
-    // cutlass_act_set.insert("swish");
-    // cutlass_act_set.insert("relu");
-    // cutlass_act_set.insert("identity");
-    // cutlass_act_set.insert("leaky_relu");
-
+    const auto& prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
+    int sm_version = prop.major * 10 + prop.minor;
+    // Now we only implement cutlass kernels on SM75.
+    if (sm_version == 75) {
+      // Cutlass now supports these conv + bias + act (cba) activations.
+      cutlass_act_set.insert("swish");
+      cutlass_act_set.insert("relu");
+      cutlass_act_set.insert("identity");
+      cutlass_act_set.insert("leaky_relu");
+    }
     all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
 #endif
   }
@@ -198,8 +203,8 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
   auto* filter_var = scope->FindLocalVar(conv_filter->Name());
   auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
   CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
-  // when this conv2d_fusion problem size is not supported by cutlass and not
-  // supported by cuDNN, we should not apply this pass
+  // When this conv2d_fusion problem size is supported by neither cutlass nor
+  // cuDNN, we should not apply this pass.
   int oc = filter_tensor->dims()[0];
   int ic = filter_tensor->dims()[1];
   bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
diff --git a/paddle/fluid/framework/ir/silu_fuse_pass.cc b/paddle/fluid/framework/ir/silu_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..05817968b45c6fccdf0733ee67d7fa881c7ee99c
--- /dev/null
+++ b/paddle/fluid/framework/ir/silu_fuse_pass.cc
@@ -0,0 +1,85 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
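Worth noting before the implementation below: the matched subgraph sigmoid(x) followed by elementwise_mul(x, sigmoid(x)) is exactly silu(x) = x * sigmoid(x), and Paddle's swish op computes x * sigmoid(beta * x), so the subgraph can be replaced by a single swish op with beta = 1.f, which is what the handler emits. A tiny self-contained check of the identity (helper names are illustrative):

#include <cassert>
#include <cmath>

static float Sigmoid(float x) { return 1.0f / (1.0f + std::exp(-x)); }

int main() {
  for (float x = -8.0f; x <= 8.0f; x += 0.25f) {
    float silu = x * Sigmoid(x);               // the sigmoid + elementwise_mul subgraph
    float swish_beta1 = x * Sigmoid(1.0f * x); // the fused swish op, beta = 1.f
    assert(std::fabs(silu - swish_beta1) < 1e-6f);
  }
  return 0;
}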
+ +#include "paddle/fluid/framework/ir/silu_fuse_pass.h" +#include +#include "paddle/fluid/framework/op_version_registry.h" + +namespace paddle { +namespace framework { +namespace ir { + +void SiluFusePass::ApplyImpl(ir::Graph* graph) const { + // This pass is used for cutlass, because cutlass can fuse conv + bias + silu + bool cutlass_enable = Get("use_cutlass"); + if (!cutlass_enable) { + return; + } + + const std::string pattern_name = "silu_fuse"; + FusePassBase::Init(pattern_name, graph); + + GraphPatternDetector gpd; + + auto* sigmoid_in = gpd.mutable_pattern()->NewNode("sigmoid_in"); + auto sigmoid_op = + gpd.mutable_pattern()->NewNode("sigmoid_op")->assert_is_op("sigmoid"); + auto sigmoid_out = gpd.mutable_pattern() + ->NewNode("sigmoid_out") + ->assert_is_op_output("sigmoid") + ->AsIntermediate(); + auto elementwise_mul_op = gpd.mutable_pattern() + ->NewNode("elementwise_mul_op") + ->assert_is_op("elementwise_mul"); + + auto elementwise_mul_out = gpd.mutable_pattern() + ->NewNode("elementwise_mul_out") + ->assert_is_op_output("elementwise_mul") + ->AsOutput(); + + sigmoid_op->LinksFrom({sigmoid_in}).LinksTo({sigmoid_out}); + elementwise_mul_op->LinksFrom({sigmoid_in, sigmoid_out}) + .LinksTo({elementwise_mul_out}); + + auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, + Graph* g) { + Node* sigmoid_in_node = subgraph.at(sigmoid_in); + Node* sigmoid_op_node = subgraph.at(sigmoid_op); + Node* elementwise_mul_op_node = subgraph.at(elementwise_mul_op); + Node* elementwise_mul_out_node = subgraph.at(elementwise_mul_out); + + OpDesc new_desc; + new_desc.SetType("swish"); + new_desc.SetAttr("beta", 1.f); + new_desc.SetInput("X", {sigmoid_in_node->Name()}); + new_desc.SetOutput("Out", {elementwise_mul_out_node->Name()}); + new_desc.Flush(); + + std::unordered_set del_node_set; + del_node_set.insert(sigmoid_op_node); + del_node_set.insert(elementwise_mul_op_node); + GraphSafeRemoveNodes(graph, del_node_set); + + auto fused_node = graph->CreateOpNode(&new_desc); + IR_NODE_LINK_TO(sigmoid_in_node, fused_node); + IR_NODE_LINK_TO(fused_node, elementwise_mul_out_node); + }; + gpd(graph, handler); +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(silu_fuse_pass, paddle::framework::ir::SiluFusePass); diff --git a/paddle/fluid/framework/ir/silu_fuse_pass.h b/paddle/fluid/framework/ir/silu_fuse_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6098c6c9b0bcebb49ca92cfdbe3bd62f50653f34 --- /dev/null +++ b/paddle/fluid/framework/ir/silu_fuse_pass.h @@ -0,0 +1,35 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+class SiluFusePass : public FusePassBase {
+ public:
+  virtual ~SiluFusePass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/inference/analysis/argument.h b/paddle/fluid/inference/analysis/argument.h
index 002eb29b776ea083534f4db85c7ad8e2813356cd..f8a4df0617190c539dbf863e92e2bbe331dbdd43 100644
--- a/paddle/fluid/inference/analysis/argument.h
+++ b/paddle/fluid/inference/analysis/argument.h
@@ -202,6 +202,7 @@ struct Argument {
   // Passed from config.
   DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
+  DECL_ARGUMENT_FIELD(use_cutlass, UseCutlass, bool);
   DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
   DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
diff --git a/paddle/fluid/inference/analysis/ir_pass_manager.cc b/paddle/fluid/inference/analysis/ir_pass_manager.cc
index c184d94ba7fdf7b9890313a4f8c068c9066483fd..ed82dfbaa04e7d95b3306c9e4fe37ead37bba1f4 100644
--- a/paddle/fluid/inference/analysis/ir_pass_manager.cc
+++ b/paddle/fluid/inference/analysis/ir_pass_manager.cc
@@ -52,6 +52,7 @@ void IRPassManager::CreatePasses(Argument *argument,
   for (const std::string &pass_name : passes) {
     auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
     pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
+    pass->Set("use_cutlass", new bool(argument->use_cutlass()));
     pass->Set("with_interleaved",
               new bool(argument->tensorrt_with_interleaved()));
     pass->Set("tensorrt_transformer_posid",
@@ -80,6 +81,10 @@ void IRPassManager::CreatePasses(Argument *argument,
     pass->Set("optim_shape_tensor",
               new std::map<std::string, std::vector<int>>());

+    // This gpu_device_id is used by some fp16 precision passes, so it is set
+    // here for every pass.
+    pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
+
     // tuned trt dynamic_shape
     pass->Set("trt_tuned_dynamic_shape",
               new bool(argument->tensorrt_tuned_dynamic_shape()));
@@ -198,7 +203,6 @@ void IRPassManager::CreatePasses(Argument *argument,
             "model_opt_cache_dir",
             new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
       }
-      pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
       pass->Set("use_static_engine", new bool(use_static_engine));
       pass->Set("model_from_memory", new bool(argument->model_from_memory()));
       pass->Set("use_inspector", new bool(argument->tensorrt_use_inspector()));
diff --git a/paddle/fluid/inference/api/analysis_config.cc b/paddle/fluid/inference/api/analysis_config.cc
index 246cfc44e81dc4b7e2d556cceb7901f07bffb0b6..5d71c7cee1d4356b3475eca7a5187175b3b16165 100644
--- a/paddle/fluid/inference/api/analysis_config.cc
+++ b/paddle/fluid/inference/api/analysis_config.cc
@@ -115,6 +115,17 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
   Update();
 }

+void AnalysisConfig::Exp_EnableUseCutlass() {
+#if defined(PADDLE_WITH_CUTLASS)
+  use_cutlass_ = true;
+#else
+  LOG(ERROR) << "Please compile with cutlass to use Exp_EnableUseCutlass()";
+  use_cutlass_ = false;
+#endif
+
+  Update();
+}
+
 void AnalysisConfig::SetExecStream(void *stream) {
   PADDLE_ENFORCE_NOT_NULL(
       stream,
@@ -389,6 +400,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
   CP_MEMBER(use_fc_padding_);
   // GPU related.
   CP_MEMBER(use_gpu_);
+  CP_MEMBER(use_cutlass_);
   CP_MEMBER(use_external_stream_);
   CP_MEMBER(exec_stream_);
   CP_MEMBER(use_cudnn_);
@@ -1249,6 +1261,7 @@ std::string AnalysisConfig::Summary() {
   // gpu info
   os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
   if (use_gpu_) {
+    os.InsertRow({"use_cutlass", use_cutlass_ ? "true" : "false"});
     os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
     os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
     os.InsertRow({"memory_pool_init_size",
diff --git a/paddle/fluid/inference/api/analysis_predictor.cc b/paddle/fluid/inference/api/analysis_predictor.cc
index 46ec559939e8eecca75323e9f55935a1d795c7eb..0fb11279ebdf9cb78b316acfcaa2e08d73048b6b 100644
--- a/paddle/fluid/inference/api/analysis_predictor.cc
+++ b/paddle/fluid/inference/api/analysis_predictor.cc
@@ -1088,6 +1088,7 @@ void AnalysisPredictor::PrepareArgument() {
   // Init std::unique_ptr<Argument> argument_.
   argument_.reset(new Argument);
   argument_->SetUseGPU(config_.use_gpu());
+  argument_->SetUseCutlass(config_.use_cutlass_);
   argument_->SetUseFcPadding(config_.use_fc_padding());
   argument_->SetGPUDeviceId(config_.gpu_device_id());
   argument_->SetEnableIrOptim(config_.enable_ir_optim_);
diff --git a/paddle/fluid/inference/api/paddle_analysis_config.h b/paddle/fluid/inference/api/paddle_analysis_config.h
index 41eea1fb98c319b4a70e2a961194df55fee4f35d..0adeaf356de0ac2a131de1e8845a2e6d66a0b44b 100644
--- a/paddle/fluid/inference/api/paddle_analysis_config.h
+++ b/paddle/fluid/inference/api/paddle_analysis_config.h
@@ -395,6 +395,12 @@ struct PD_INFER_DECL AnalysisConfig {
   ///
   bool use_gpu() const { return use_gpu_; }
   ///
+  /// \brief When running an fp16 model on an Nvidia GPU, you can also try
+  /// running it with the cutlass kernels.
+  ///
+  void Exp_EnableUseCutlass();
+  ///
+  ///
   /// \brief A boolean state telling whether the XPU is turned on.
   ///
   /// \return bool Whether the XPU is turned on.
@@ -1047,6 +1053,7 @@ struct PD_INFER_DECL AnalysisConfig {

   // GPU related.
   bool use_gpu_{false};
+  bool use_cutlass_{false};
   int gpu_device_id_{0};
   uint64_t memory_pool_init_size_mb_{100};  // initial size is 100MB.
   bool enable_gpu_mixed_{false};
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 46eca6df552c6fd7705a2c1e8a70d75a28c6d8e7..b4018d883a028d11b116e8d33d9d846eafff807e 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -164,6 +164,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
 const std::vector<std::string> kGpuLowerPrecisionPasses{
     "identity_scale_op_clean_pass",
     "simplify_with_basic_ops_pass",
+    "silu_fuse_pass",
     "delete_quant_dequant_linear_op_pass",
     "delete_weight_dequant_linear_op_pass",
     "map_depthwise_conv_to_conv_pass",
@@ -172,6 +173,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
     "conv_elementwise_add_act_fuse_pass",
     "conv_elementwise_add2_act_fuse_pass",
     "conv_elementwise_add_fuse_pass",
+    "conv2d_fusion_layout_transfer_pass",
     "multihead_matmul_fuse_pass_v2",
     "fused_multi_transformer_encoder_pass",
     "fused_multi_transformer_decoder_pass",
@@ -216,6 +218,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
         "delete_weight_dequant_linear_op_pass",       //
         "map_depthwise_conv_to_conv_pass",            //
         "constant_folding_pass",                      //
+        "silu_fuse_pass",                             //
         "conv_bn_fuse_pass",                          //
         "conv_eltwiseadd_bn_fuse_pass",               //
         "embedding_eltwise_layernorm_fuse_pass",      //
@@ -250,7 +253,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
 #endif   //
         "transpose_flatten_concat_fuse_pass",  //
         "constant_folding_pass",               //
-        "auto_mixed_precision_pass",           //
+        "conv2d_fusion_layout_transfer_pass",  //
+        "auto_mixed_precision_pass"
   });

   use_gpu_ = true;
diff --git a/paddle/fluid/operators/fused/conv_fusion_op.cc b/paddle/fluid/operators/fused/conv_fusion_op.cc
index 27440c9408baac4a2d999cf8785a395585d16047..022c21a205dd4ace957c330afbb17fc6378d278f 100644
--- a/paddle/fluid/operators/fused/conv_fusion_op.cc
+++ b/paddle/fluid/operators/fused/conv_fusion_op.cc
@@ -330,3 +330,13 @@ REGISTER_OPERATOR(
     ops::ConvOpInferVarType,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+// This op is used by cutlass; conv2d_fusion_cutlass is an intermediate op
+// produced by conv2d_fusion_layout_transfer_pass.
+REGISTER_OPERATOR(
+    conv2d_fusion_cutlass,
+    ops::Conv2DFusionOp,
+    ops::Conv2DFusionOpMaker,
+    ops::ConvOpInferVarType,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
diff --git a/paddle/fluid/pybind/inference_api.cc b/paddle/fluid/pybind/inference_api.cc
index 9a791e4f2e36243931216b409e03d83de8e26865..d314a9a7835190643b165ae287a52531d87b4b9d 100644
--- a/paddle/fluid/pybind/inference_api.cc
+++ b/paddle/fluid/pybind/inference_api.cc
@@ -646,6 +646,7 @@ void BindAnalysisConfig(py::module *m) {
            py::arg("memory_pool_init_size_mb"),
            py::arg("device_id") = 0,
            py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
+      .def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass)
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
       .def("set_exec_stream",
            [](AnalysisConfig &self, phi::CUDAStream &stream) {
diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt
index e12c5f10fd1c4cb5a0da65d044f606e4af9f709b..25bbd17c4feab2dbc4a57b6578ebc3662ab83fcc 100644
--- a/paddle/phi/kernels/CMakeLists.txt
+++ b/paddle/phi/kernels/CMakeLists.txt
@@ -106,8 +106,7 @@ file(
   "fusion/gpu/*.cu")

 if(WITH_CUTLASS)
-  file(GLOB cutlass_cu "fusion/cutlass/default_moe_fc_traits.h"
-       "fusion/cutlass/linear_combination_ft_gelu.h" "fusion/cutlass/moe*")
+  file(GLOB cutlass_cu "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu")
   list(APPEND kernel_cu ${cutlass_cu})
 endif()
diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu
new file mode 100644
index 0000000000000000000000000000000000000000..308fd276c12be527d8fb21078eb6e95ba2ee4e6b
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias.cu
@@ -0,0 +1,225 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
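With the plumbing above in place (config flag, Argument field, pass option, op registration, Python binding), opting in from user code is a single extra call. A minimal C++ setup sketch; the model paths are placeholders, and Exp_EnableUseCutlass() logs an error and stays off unless Paddle was built with PADDLE_WITH_CUTLASS:

#include "paddle/fluid/inference/api/paddle_analysis_config.h"

paddle::AnalysisConfig MakeCutlassConfig() {
  paddle::AnalysisConfig config;
  config.SetModel("model.pdmodel", "model.pdiparams");  // placeholder paths
  // The cutlass kernels are fp16, so pair the switch with GPU fp16 inference.
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100, /*device_id=*/0,
                      paddle::AnalysisConfig::Precision::kHalf);
  config.Exp_EnableUseCutlass();
  return config;
}

From Python, the same switch is config.exp_enable_use_cutlass(), as bound in inference_api.cc above.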
+ +#include +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { + +template +cutlass::Status Conv2dBiasImpl(ConvAllParams params) { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + using MMAOp = cutlass::arch::OpClassTensorOp; + using SmArch = cutlass::arch::Sm75; + using ThreadblockShape = TShape; + using WarpShape = WShape; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + constexpr int NumStages = 2; + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + using EpilogueOp = + cutlass::epilogue::thread::LinearCombination; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided, + Alignment, + Alignment>::Kernel; + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, + {oc, kh, kw, ic}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + mode, + 1); + + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, + {(cutlass::half_t *)(bias), {0, 0, 0}}, + {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f}}; + + ImplicitGemm implicit_gemm_op; + size_t bytes = implicit_gemm_op.get_workspace_size(arguments); + + auto ctx = params.ctx; + auto stream = ctx->stream(); + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + paddle::memory::Alloc( + ctx->GetPlace(), + bytes, + phi::Stream(reinterpret_cast(stream))); + void *workspace = tmp_gpu_ptrs_data->ptr(); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = implicit_gemm_op.initialize(arguments, workspace); + CUTLASS_CHECK(status); + status = implicit_gemm_op(stream); + CUTLASS_CHECK(status); + return status; +} + +// config 0 +template cutlass::Status 
Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>( + ConvAllParams); +// config 1 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>( + ConvAllParams); +// config 2 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>( + ConvAllParams); +// config 3 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>( + ConvAllParams); +// config 4 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 32>>( + ConvAllParams); +// config 5 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 64, 32>>( + ConvAllParams); +// config 6 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<64, 64, 32>>( + ConvAllParams); +// config 7 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<64, 64, 32>>( + ConvAllParams); +// config 8 +template cutlass::Status Conv2dBiasImpl, + cutlass::gemm::GemmShape<64, 32, 32>>( + ConvAllParams); + +std::vector> + conv2d_bias_all_func = { + Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 32, 32>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<32, 64, 32>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasImpl, + cutlass::gemm::GemmShape<64, 32, 32>>}; + +std::map, int> map_problem_conv2d_bias; +std::mutex conv2d_bias_mutex; + +void Conv2dBias(ConvAllParams params) { + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + std::vector problem_size = { + batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; + + if (map_problem_conv2d_bias.count(problem_size)) { + conv2d_bias_all_func[map_problem_conv2d_bias.at(problem_size)](params); + return; + } + + int best_config_index = + ProfileToGetBestConfig(conv2d_bias_all_func, params, CONV2D_BIAS); + + std::lock_guard guard(conv2d_bias_mutex); + map_problem_conv2d_bias[problem_size] = best_config_index; + conv2d_bias_all_func[best_config_index](params); +} +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..3fac4f5673b7f0bb09b3d1afca4213610227a1c0 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_add_relu.cu @@ -0,0 +1,248 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
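A note on the TensorRef arguments built in Conv2dBiasImpl above: for a packed TensorNHWC tensor the stride triple is {C, C * W, C * W * H} for the (w, h, n) dimensions, which is where {ic, ic * iw, ic * iw * ih} for the input and {oc, oc * ow, oc * ow * oh} for the output come from; the bias passes {0, 0, 0} so a single length-oc vector broadcasts across every output position. A scalar sketch of the implied indexing:

// offset(n, h, w, c) for a packed NHWC tensor of shape [N, H, W, C]:
//   n * (H * W * C) + h * (W * C) + w * C + c,
// i.e. strides {C, C * W, C * W * H} for the w, h and n dimensions,
// matching the TensorRef stride triples used above.
int OffsetNHWC(int n, int h, int w, int c, int H, int W, int C) {
  return ((n * H + h) * W + w) * C + c;
}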
+ +#include +#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { + +namespace cutlass_internal { + +template +cutlass::Status Conv2dBiasAddReluImpl(ConvAllParams params) { + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + cutlass::half_t, + float, + float, + cutlass::half_t, + Alignment, + cutlass::epilogue::thread::Identity, + cutlass::plus, + cutlass::epilogue::thread::ReLu>; + + using Conv2dFpropKernel = + typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast< + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::layout::TensorNHWC, + cutlass::half_t, + cutlass::layout::TensorNHWC, + float, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm75, + TShape, + WShape, + cutlass::gemm::GemmShape<16, 8, 8>, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>, + 2, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::IteratorAlgorithm::kOptimized, + cutlass::conv::StrideSupport::kStrided, + Alignment, + Alignment>::Kernel; + + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + const half *residual = params.residual; + + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Conv2dProblemSize problem_size( + {batch, ih, iw, ic}, + {oc, kh, kw, ic}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + cutlass::conv::Mode::kCrossCorrelation, + 1); + + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)weight, {ic, ic * kw, ic * kw * kh}}, + {(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}}, + {(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f}, + cutlass::conv::SplitKMode::kSerial, + (cutlass::half_t *)(bias), + nullptr, + 0, + oc}; + + ImplicitGemm implicit_gemm_op; + size_t bytes = implicit_gemm_op.get_workspace_size(arguments); + + auto ctx = params.ctx; + auto stream = ctx->stream(); + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + paddle::memory::Alloc( + ctx->GetPlace(), + bytes, + phi::Stream(reinterpret_cast(stream))); + void *workspace = tmp_gpu_ptrs_data->ptr(); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = implicit_gemm_op.initialize(arguments, workspace); + CUTLASS_CHECK(status); + status = implicit_gemm_op(stream); + CUTLASS_CHECK(status); + return status; +} + +// config 0 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 1 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 2 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 3 
+template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 4 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); +// config 5 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); +// config 6 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 7 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 8 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); +// config 9 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 10 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 11 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 12 +template cutlass::Status + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); + +std::vector> + conv2d_bias_add_relu_all_func = { + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<32, 64, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 32, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasAddReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>}; +std::map, int> map_problem_conv2d_bias_add_relu; +std::mutex conv2d_bias_add_relu_mutex; + +void Conv2dBiasAddRelu(ConvAllParams params) { + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + std::vector problem_size = { + batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; + + if (map_problem_conv2d_bias_add_relu.count(problem_size)) { + conv2d_bias_add_relu_all_func[map_problem_conv2d_bias_add_relu.at( + problem_size)](params); + return; + } + + std::lock_guard guard(conv2d_bias_add_relu_mutex); + + // config 6's diff is large. 
+ conv2d_bias_add_relu_all_func[6] = nullptr; + + int best_config_index = ProfileToGetBestConfig( + conv2d_bias_add_relu_all_func, params, CONV2D_BIAS_ADD_RELU); + map_problem_conv2d_bias_add_relu[problem_size] = best_config_index; + conv2d_bias_add_relu_all_func[best_config_index](params); +} +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..97ca75e477644ccddf83d4dd24b3b1b98cc04769 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_leaky_relu.cu @@ -0,0 +1,226 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { +template +cutlass::Status Conv2dBiasLeakyReluImpl(ConvAllParams params) { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + using MMAOp = cutlass::arch::OpClassTensorOp; + using SmArch = cutlass::arch::Sm75; + using ThreadblockShape = TShape; + using WarpShape = WShape; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + constexpr int NumStages = 2; + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationLeakyRelu< + ElementOutput, + Alignment, + float, + ElementComputeEpilogue>; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided, + Alignment, + Alignment>::Kernel; + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + 
int stride_h = params.stride_h; + int stride_w = params.stride_w; + float alpha = params.alpha; + + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, + {oc, kh, kw, ic}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + mode, + 1); + + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, + {(cutlass::half_t *)(bias), {0, 0, 0}}, + {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f, alpha}}; + + ImplicitGemm implicit_gemm_op; + size_t bytes = implicit_gemm_op.get_workspace_size(arguments); + + auto ctx = params.ctx; + auto stream = ctx->stream(); + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + paddle::memory::Alloc( + ctx->GetPlace(), + bytes, + phi::Stream(reinterpret_cast(stream))); + void *workspace = tmp_gpu_ptrs_data->ptr(); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = implicit_gemm_op.initialize(arguments, workspace); + CUTLASS_CHECK(status); + status = implicit_gemm_op(stream); + CUTLASS_CHECK(status); + return status; +} + +// config 0 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 1 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 2 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<128, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 3 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 4 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); +// config 5 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); +// config 6 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 7 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 8 +template cutlass::Status Conv2dBiasLeakyReluImpl< + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); + +std::vector> + conv2d_bias_leaky_relu_all_func = { + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<32, 32, 32>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<32, 64, 32>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasLeakyReluImpl, + cutlass::gemm::GemmShape<64, 32, 32>>}; + 
+std::map, int> map_problem_conv2d_bias_leaky_relu; +std::mutex conv2d_bias_leaky_relu_mutex; + +void Conv2dBiasLeakyRelu(ConvAllParams params) { + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + std::vector problem_size = { + batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; + + if (map_problem_conv2d_bias_leaky_relu.count(problem_size)) { + conv2d_bias_leaky_relu_all_func[map_problem_conv2d_bias_leaky_relu.at( + problem_size)](params); + return; + } + + int best_config_index = ProfileToGetBestConfig( + conv2d_bias_leaky_relu_all_func, params, CONV2D_BIAS_LEAKY_RELU); + + std::lock_guard guard(conv2d_bias_leaky_relu_mutex); + map_problem_conv2d_bias_leaky_relu[problem_size] = best_config_index; + conv2d_bias_leaky_relu_all_func[best_config_index](params); +} +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu new file mode 100644 index 0000000000000000000000000000000000000000..a5f5a9bee12c644b12d5b493407e17b7ccaef6e1 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu.cu @@ -0,0 +1,225 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
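Across these kernels the epilogue scalars {1.f, 1.f} are alpha (scaling the conv accumulator) and beta (scaling the broadcast bias/source term), and Conv2dBiasLeakyReluImpl appends params.alpha as the leaky slope. A scalar sketch of the intended per-element semantics; this is my reading of the epilogue functors, not cutlass code:

// out = leaky_relu(alpha * accum + beta * bias), with alpha = beta = 1.f
// and slope = params.alpha in Conv2dBiasLeakyReluImpl above.
float EpilogueBiasLeakyRelu(float accum, float bias, float slope) {
  float x = 1.0f * accum + 1.0f * bias;
  return x > 0.0f ? x : slope * x;
}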
+ +#include +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { +template +cutlass::Status Conv2dBiasReluImpl(ConvAllParams params) { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + using MMAOp = cutlass::arch::OpClassTensorOp; + using SmArch = cutlass::arch::Sm75; + using ThreadblockShape = TShape; + using WarpShape = WShape; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + constexpr int NumStages = 2; + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + using EpilogueOp = + cutlass::epilogue::thread::LinearCombinationRelu; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided, + Alignment, + Alignment>::Kernel; + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + + int stride_h = params.stride_h; + int stride_w = params.stride_w; + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, + {oc, kh, kw, ic}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + mode, + 1); + + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, + {(cutlass::half_t *)(bias), {0, 0, 0}}, + {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f}}; + + ImplicitGemm implicit_gemm_op; + size_t bytes = implicit_gemm_op.get_workspace_size(arguments); + + auto ctx = params.ctx; + auto stream = ctx->stream(); + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + paddle::memory::Alloc( + ctx->GetPlace(), + bytes, + phi::Stream(reinterpret_cast(stream))); + void *workspace = tmp_gpu_ptrs_data->ptr(); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = implicit_gemm_op.initialize(arguments, workspace); + CUTLASS_CHECK(status); + status = implicit_gemm_op(stream); + CUTLASS_CHECK(status); + return status; +} + +// config 0 +template cutlass::Status + 
Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 1 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 2 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 3 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 4 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); +// config 5 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); +// config 6 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 7 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 8 +template cutlass::Status + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); + +std::vector> + conv2d_bias_relu_all_func = { + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 32, 32>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<32, 64, 32>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasReluImpl, + cutlass::gemm::GemmShape<64, 32, 32>>}; +std::map, int> map_problem_conv2d_bias_relu; +std::mutex conv2d_bias_relu_mutex; + +void Conv2dBiasRelu(ConvAllParams params) { + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + std::vector problem_size = { + batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; + + if (map_problem_conv2d_bias_relu.count(problem_size)) { + conv2d_bias_relu_all_func[map_problem_conv2d_bias_relu.at(problem_size)]( + params); + return; + } + + int best_config_index = ProfileToGetBestConfig( + conv2d_bias_relu_all_func, params, CONV2D_BIAS_RELU); + + std::lock_guard guard(conv2d_bias_relu_mutex); + map_problem_conv2d_bias_relu[problem_size] = best_config_index; + conv2d_bias_relu_all_func[best_config_index](params); +} +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu_few_channels.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu_few_channels.cu new file mode 100644 index 0000000000000000000000000000000000000000..1acd191033529eb3b0aff8c616fec21be59f5265 --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_relu_few_channels.cu @@ -0,0 +1,218 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/epilogue/thread/linear_combination_bias_relu.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { +template +cutlass::Status Conv2dBiasReluFewChannelsImpl(ConvAllParams params) { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + using MMAOp = cutlass::arch::OpClassTensorOp; + using SmArch = cutlass::arch::Sm75; + using ThreadblockShape = TShape; + using WarpShape = WShape; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + constexpr int NumStages = 2; + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFewChannels; + using EpilogueOp = + cutlass::epilogue::thread::LinearCombinationRelu; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided, + Alignment, + Alignment>::Kernel; + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w1; + + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, + {oc, kh, kw, ic}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + mode, + 1); + + typename ImplicitGemm::Arguments arguments{ + problem_size, + {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}}, + {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}}, + {(cutlass::half_t *)(bias), {0, 0, 0}}, + {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}}, + {1.f, 1.f}}; + + ImplicitGemm implicit_gemm_op; + size_t bytes = implicit_gemm_op.get_workspace_size(arguments); + + auto ctx = params.ctx; + auto stream = ctx->stream(); + paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data = + 
paddle::memory::Alloc( + ctx->GetPlace(), + bytes, + phi::Stream(reinterpret_cast(stream))); + void *workspace = tmp_gpu_ptrs_data->ptr(); + + cutlass::Status status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + status = implicit_gemm_op.initialize(arguments, workspace); + CUTLASS_CHECK(status); + status = implicit_gemm_op(stream); + CUTLASS_CHECK(status); + return status; +} + +// config 0 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<64, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 1 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 2 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<128, 32, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 3 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<128, 64, 64>, + cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams); +// config 4 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams); +// config 5 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<64, 128, 32>, + cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams); +// config 6 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 7 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<64, 256, 32>, + cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams); +// config 8 +template cutlass::Status Conv2dBiasReluFewChannelsImpl< + cutlass::gemm::GemmShape<128, 64, 32>, + cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams); + +std::vector> + conv2d_bias_relu_few_channels_all_func = { + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<32, 32, 64>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<32, 32, 32>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<32, 64, 32>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<64, 64, 32>>, + Conv2dBiasReluFewChannelsImpl, + cutlass::gemm::GemmShape<64, 32, 32>>}; +std::map, int> map_problem_conv2d_bias_relu_few_channels; + +void Conv2dBiasReluFewChannels(ConvAllParams params) { + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w1; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + std::vector problem_size = { + batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w}; + + if (map_problem_conv2d_bias_relu_few_channels.count(problem_size)) { + conv2d_bias_relu_few_channels_all_func + [map_problem_conv2d_bias_relu_few_channels.at(problem_size)](params); + return; + } + // +} +} // namespace cutlass_internal +} // namespace fusion +} // namespace phi diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu 
b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu new file mode 100644 index 0000000000000000000000000000000000000000..469585ccf8398b9e49d12eb2accc1481eb9b84af --- /dev/null +++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_bias_silu.cu @@ -0,0 +1,226 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/epilogue/thread/linear_combination_silu.h" +#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h" + +namespace phi { +namespace fusion { +namespace cutlass_internal { +template +cutlass::Status Conv2dBiasSiluImpl(ConvAllParams params) { + using ElementAccumulator = float; + using ElementComputeEpilogue = float; + using ElementInputA = cutlass::half_t; + using ElementInputB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using LayoutInputA = cutlass::layout::TensorNHWC; + using LayoutInputB = cutlass::layout::TensorNHWC; + using LayoutOutput = cutlass::layout::TensorNHWC; + using MMAOp = cutlass::arch::OpClassTensorOp; + using SmArch = cutlass::arch::Sm75; + using ThreadblockShape = TShape; + using WarpShape = WShape; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using SwizzleThreadBlock = + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>; + constexpr int NumStages = 2; + static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kOptimized; + using EpilogueOp = + cutlass::epilogue::thread::LinearCombinationSilu; + + using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kStrided, + Alignment, + Alignment>::Kernel; + using ImplicitGemm = + cutlass::conv::device::ImplicitGemmConvolution; + + const half *input = params.input; + const half *weight = params.weight; + const half *bias = params.bias; + half *output = params.output; + int batch = params.batch; + int ic = params.ic; + int ih = params.ih; + int iw = params.iw; + int kh = params.kh; + int kw = params.kw; + int oc = params.oc; + int pad_h0 = params.pad_h0; + int pad_w0 = params.pad_w0; + int stride_h = params.stride_h; + int stride_w = params.stride_w; + + int oh = params.oh; + int ow = params.ow; + int dilation_h = params.dilation_h; + int dilation_w = params.dilation_w; + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic}, + {oc, kh, kw, ic}, + {pad_h0, 0, pad_w0, 0}, + {stride_h, stride_w}, + {dilation_h, dilation_w}, + {batch, oh, ow, oc}, + mode, + 1); + + typename ImplicitGemm::Arguments arguments{ + problem_size, + 
+      {(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
+      {(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
+      {(cutlass::half_t *)(bias), {0, 0, 0}},
+      {(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
+      {1.f, 1.f}};
+
+  ImplicitGemm implicit_gemm_op;
+  size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
+
+  auto ctx = params.ctx;
+  auto stream = ctx->stream();
+  paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
+      paddle::memory::Alloc(
+          ctx->GetPlace(),
+          bytes,
+          phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
+  void *workspace = tmp_gpu_ptrs_data->ptr();
+
+  cutlass::Status status = implicit_gemm_op.can_implement(arguments);
+  CUTLASS_CHECK(status);
+  status = implicit_gemm_op.initialize(arguments, workspace);
+  CUTLASS_CHECK(status);
+  status = implicit_gemm_op(stream);
+  CUTLASS_CHECK(status);
+  return status;
+}
+
+// config 0
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
+                       cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
+// config 1
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
+                       cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
+// config 2
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
+                       cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
+// config 3
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
+                       cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
+// config 4
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
+                       cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
+// config 5
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
+                       cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
+// config 6
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
+                       cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
+// config 7
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
+                       cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
+// config 8
+template cutlass::Status
+    Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
+                       cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
+
+std::vector<std::function<cutlass::Status(ConvAllParams)>>
+    conv2d_bias_silu_all_func = {
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
+                           cutlass::gemm::GemmShape<32, 32, 64>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
+                           cutlass::gemm::GemmShape<32, 32, 64>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
+                           cutlass::gemm::GemmShape<32, 32, 64>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
+                           cutlass::gemm::GemmShape<32, 32, 64>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
+                           cutlass::gemm::GemmShape<32, 32, 32>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
+                           cutlass::gemm::GemmShape<32, 64, 32>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
+                           cutlass::gemm::GemmShape<64, 64, 32>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
+                           cutlass::gemm::GemmShape<64, 64, 32>>,
+        Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
+                           cutlass::gemm::GemmShape<64, 32, 32>>};
+
+std::map<std::vector<int>, int> map_problem_conv2d_bias_silu;
+std::mutex conv2d_bias_silu_mutex;
+
+void Conv2dBiasSilu(ConvAllParams params) {
+  int batch = params.batch;
+  int ic = params.ic;
+  int ih = params.ih;
+  int iw = params.iw;
+  int kh = params.kh;
+  int kw = params.kw;
+  int oc = params.oc;
+  int pad_h0 = params.pad_h0;
+  int pad_w0 = params.pad_w0;
+  int stride_h = params.stride_h;
+  int stride_w = params.stride_w;
+
+  std::vector<int> problem_size = {
+      batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
+
+  if (map_problem_conv2d_bias_silu.count(problem_size)) {
+    conv2d_bias_silu_all_func[map_problem_conv2d_bias_silu.at(problem_size)](
+        params);
+    return;
+  }
+
+  int best_config_index = ProfileToGetBestConfig(
+      conv2d_bias_silu_all_func, params, CONV2D_BIAS_SILU);
+
+  std::lock_guard<std::mutex> guard(conv2d_bias_silu_mutex);
+
+  map_problem_conv2d_bias_silu[problem_size] = best_config_index;
+  conv2d_bias_silu_all_func[best_config_index](params);
+}
+}  // namespace cutlass_internal
+}  // namespace fusion
+}  // namespace phi
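Conv2dBiasSilu above memoizes the winning tile config per problem size behind a mutex: the first call for a given shape profiles all candidates, later calls reuse the cached index. A condensed, self-contained sketch of that dispatch pattern (all names here are illustrative, not from the patch):

```cpp
#include <functional>
#include <map>
#include <mutex>
#include <vector>

using Kernel = std::function<void(int)>;  // stands in for the real functors

void Dispatch(const std::vector<Kernel> &candidates,
              const std::vector<int> &problem_size,
              int params,
              std::map<std::vector<int>, int> *cache,
              std::mutex *mu,
              const std::function<int()> &profile) {
  if (cache->count(problem_size)) {        // fast path: reuse cached winner
    candidates[cache->at(problem_size)](params);
    return;
  }
  int best = profile();                    // slow path: benchmark once
  std::lock_guard<std::mutex> guard(*mu);  // guard the shared cache on insert
  (*cache)[problem_size] = best;
  candidates[best](params);
}
```

Note that, as in the patch, the fast-path lookup is not under the lock; a stricter variant would take the mutex around the read as well.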
diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h
new file mode 100644
index 0000000000000000000000000000000000000000..b740d49fc1dc3fb0e021e80f436e783c5a392aea
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h
@@ -0,0 +1,61 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+#include <cuda_fp16.h>
+#include <map>
+#include <mutex>
+#include <vector>
+
+#include "paddle/phi/backends/gpu/gpu_context.h"
+
+namespace phi {
+namespace fusion {
+namespace cutlass_internal {
+
+typedef struct {
+  const half *input;
+  const half *weight;
+  const half *bias;
+  const half *residual;
+  half *output;
+  int batch;
+  int ic;
+  int ih;
+  int iw;
+  int kh;
+  int kw;
+  int oc;
+  int pad_h0;
+  int pad_h1;
+  int pad_w0;
+  int pad_w1;
+  int stride_h;
+  int stride_w;
+  int dilation_h;
+  int dilation_w;
+  int oh;
+  int ow;
+  const phi::GPUContext *ctx;
+  float alpha;  // only used by leaky_relu
+} ConvAllParams;
+
+// The functions below are implemented with cutlass and are called by phi.
+void Conv2dBiasAddRelu(ConvAllParams params);
+void Conv2dBiasRelu(ConvAllParams params);
+void Conv2dBiasLeakyRelu(ConvAllParams params);
+void Conv2dBiasSilu(ConvAllParams params);
+void Conv2dBias(ConvAllParams params);
+}  // namespace cutlass_internal
+}  // namespace fusion
+}  // namespace phi
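ConvAllParams is the single plain-data handoff between phi and the cutlass kernels. A hedged sketch of how a caller might populate it for a 3x3, stride-1, NHWC convolution; the helper name and the concrete sizes are illustrative, and the device pointers and GPU context are assumed to exist already:

```cpp
// Assumes in/w/b/out are device buffers laid out NHWC (input/output) and
// OHWI (filter), and gpu_ctx points at a live phi::GPUContext.
ConvAllParams MakeParams(const half *in, const half *w, const half *b,
                         half *out, const phi::GPUContext *gpu_ctx) {
  ConvAllParams p = {};
  p.input = in;  p.weight = w;  p.bias = b;  p.residual = nullptr;
  p.output = out;
  p.batch = 1;  p.ic = 64;  p.ih = 56;  p.iw = 56;  // hypothetical sizes
  p.kh = 3;     p.kw = 3;   p.oc = 64;
  p.pad_h0 = p.pad_h1 = p.pad_w0 = p.pad_w1 = 1;    // symmetric padding
  p.stride_h = p.stride_w = 1;
  p.dilation_h = p.dilation_w = 1;
  p.oh = 56;  p.ow = 56;  // output spatial dims for these sizes
  p.ctx = gpu_ctx;
  p.alpha = 0.f;          // only read by the leaky_relu epilogue
  return p;
}
```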
diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu
new file mode 100644
index 0000000000000000000000000000000000000000..174cb4aaa405956811c8b3203a1c08f97376be97
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.cu
@@ -0,0 +1,277 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
+
+namespace phi {
+namespace fusion {
+namespace cutlass_internal {
+struct logical_coord {
+  int n;
+  int c;
+  int h;
+  int w;
+};
+
+float diff(const half *c, const float *c_baseline, int n) {
+  float max_diff = -1.;
+  for (int i = 0; i < n; i++) {
+    float c_value = __half2float(c[i]);
+    if (std::abs(c_baseline[i] - c_value) > max_diff) {
+      max_diff = std::abs(c_baseline[i] - c_value);
+    }
+  }
+  return max_diff;
+}
+
+__device__ int gpu_nhwc(struct logical_coord shape,
+                        struct logical_coord index) {
+  return index.n * shape.h * shape.w * shape.c + index.h * shape.w * shape.c +
+         index.w * shape.c + index.c;
+}
+
+__global__ void naive_conv2d_kernel(const half *input,
+                                    const half *weight,
+                                    const half *bias,
+                                    float *output,
+                                    int batch,
+                                    int ic,
+                                    int ih,
+                                    int iw,
+                                    int kh,
+                                    int kw,
+                                    int oc,
+                                    int pad_h,
+                                    int pad_w,
+                                    int stride_h,
+                                    int stride_w,
+                                    int dilation_h,
+                                    int dilation_w,
+                                    int oh,
+                                    int ow,
+                                    const half *residual,
+                                    float alpha,  // for leaky_relu
+                                    OpType op_type) {
+  int M = batch * oh * ow;
+  int N = oc;
+  int K = ic * kh * kw;
+  int m_i = threadIdx.x + blockIdx.x * blockDim.x;
+  int n_i = threadIdx.y + blockIdx.y * blockDim.y;
+  if (m_i >= M || n_i >= N) return;
+
+  int batch_i = m_i / (oh * ow);
+  int oh_i = (m_i % (oh * ow)) / ow;
+  int ow_i = (m_i % (oh * ow)) % ow;
+  int oc_i = n_i;
+
+  struct logical_coord weight_shape = {oc, ic, kh, kw};
+  struct logical_coord input_shape = {batch, ic, ih, iw};
+  int out_offset = m_i * N + n_i;
+  float *out_ptr = output + out_offset;
+  float sum = 0.f;
+
+  for (int k_i = 0; k_i < K; k_i++) {
+    int ic_i = k_i / (kh * kw);
+    int kh_i = (k_i % (kh * kw)) / kw;
+    int kw_i = (k_i % (kh * kw)) % kw;
+
+    struct logical_coord weight_index = {oc_i, ic_i, kh_i, kw_i};
+
+    int ih_i = oh_i * stride_h - pad_h + kh_i * dilation_h;
+    int iw_i = ow_i * stride_w - pad_w + kw_i * dilation_w;
+
+    if (ih_i < 0 || ih_i >= ih) continue;
+    if (iw_i < 0 || iw_i >= iw) continue;
+
+    struct logical_coord input_index = {batch_i, ic_i, ih_i, iw_i};
+    const half *weight_ptr = weight + gpu_nhwc(weight_shape, weight_index);
+    const half *in_ptr = input + gpu_nhwc(input_shape, input_index);
+    sum += __half2float(*in_ptr) * __half2float(*weight_ptr);
+  }
+
+  sum += __half2float(*(bias + oc_i));
+  float x = sum;
+
+  switch (op_type) {
+    case CONV2D_BIAS:
+      *out_ptr = x;
+      break;
+    case CONV2D_BIAS_RELU:
+      *out_ptr = x > 0 ? x : 0;
+      break;
+    case CONV2D_BIAS_SILU:
+      *out_ptr = x * (1.f / (1 + exp(-x)));
+      break;
+    case CONV2D_BIAS_ADD_RELU:
+      x += __half2float(*(residual + out_offset));
+      *out_ptr = x > 0 ? x : 0;
+      break;
+    case CONV2D_BIAS_LEAKY_RELU:
+      *out_ptr = x > 0 ? x : (x * alpha);
+      break;
+    default:
+      break;
+  }
+}
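The reference kernel above evaluates the convolution as an implicit GEMM: one thread per output element, indexed by a row m over output pixels and a column n over output channels, with the reduction walking K = IC x KH x KW. A tiny host-side restatement of those extents (names are illustrative):

```cpp
// The GEMM view used by naive_conv2d_kernel: each output element is one
// (m, n) pair; the K loop covers the ic x kh x kw reduction.
struct GemmView { int M, N, K; };

GemmView ImplicitGemmExtents(int batch, int oh, int ow,
                             int oc, int ic, int kh, int kw) {
  return {batch * oh * ow,  // M: one row per output pixel
          oc,               // N: one column per output channel
          ic * kh * kw};    // K: reduction length
}
```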
+
+float conv2d_diff_gpu(ConvAllParams params, OpType op_type) {
+  const half *input = params.input;
+  const half *weight = params.weight;
+  const half *bias = params.bias;
+  half *output = params.output;
+  int batch = params.batch;
+  int ic = params.ic;
+  int ih = params.ih;
+  int iw = params.iw;
+  int kh = params.kh;
+  int kw = params.kw;
+  int oc = params.oc;
+  int pad_h = params.pad_h0;
+  int pad_w = params.pad_w0;
+  int stride_h = params.stride_h;
+  int stride_w = params.stride_w;
+  int dilation_h = params.dilation_h;
+  int dilation_w = params.dilation_w;
+  const half *residual = params.residual;
+
+  int oh = params.oh;
+  int ow = params.ow;
+  int M = batch * oh * ow;
+  int N = oc;
+
+  constexpr int blockM = 16;
+  constexpr int blockN = 16;
+  uint3 grid = {(M + blockM - 1) / blockM, (N + blockN - 1) / blockN, 1};
+  uint3 block = {blockM, blockN, 1};
+
+  int output_size = batch * oc * oh * ow;
+  half *output_from_cutlass =
+      reinterpret_cast<half *>(malloc(sizeof(half) * output_size));
+  // The cutlass kernel ran on the context stream; make sure it has finished
+  // before copying its output back.
+  cudaDeviceSynchronize();
+  cudaMemcpy(output_from_cutlass,
+             output,
+             output_size * sizeof(half),
+             cudaMemcpyDeviceToHost);
+
+  float *gpu_output;
+  cudaMalloc(&gpu_output, output_size * sizeof(float));
+  naive_conv2d_kernel<<<grid, block>>>(input,
+                                       weight,
+                                       bias,
+                                       gpu_output,
+                                       batch,
+                                       ic,
+                                       ih,
+                                       iw,
+                                       kh,
+                                       kw,
+                                       oc,
+                                       pad_h,
+                                       pad_w,
+                                       stride_h,
+                                       stride_w,
+                                       dilation_h,
+                                       dilation_w,
+                                       oh,
+                                       ow,
+                                       residual,
+                                       params.alpha,
+                                       op_type);
+  float *output_from_gpu =
+      reinterpret_cast<float *>(malloc(sizeof(float) * output_size));
+  cudaMemcpy(output_from_gpu,
+             gpu_output,
+             output_size * sizeof(float),
+             cudaMemcpyDeviceToHost);
+  float max_diff = diff(output_from_cutlass, output_from_gpu, output_size);
+
+  free(output_from_cutlass);
+  free(output_from_gpu);
+  cudaFree(gpu_output);
+  return max_diff;
+}
+
+std::string OpType2String(OpType op_type) {
+  switch (op_type) {
+    case CONV2D_BIAS:
+      return "conv2d_bias";
+    case CONV2D_BIAS_RELU:
+      return "conv2d_bias_relu";
+    case CONV2D_BIAS_SILU:
+      return "conv2d_bias_silu";
+    case CONV2D_BIAS_ADD_RELU:
+      return "conv2d_bias_add_relu";
+    case CONV2D_BIAS_LEAKY_RELU:
+      return "conv2d_bias_leaky_relu";
+    default:
+      break;
+  }
+  return "unnamed_op";
+}
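With conv2d_diff_gpu available, a debugging session can gate each tuned kernel on its max diff against the fp32 baseline. A sketch under the assumption that roughly 1e-2 is an acceptable fp16-vs-fp32 gap for these fusions; the helper name and threshold are illustrative, not from the patch:

```cpp
// Illustrative check: run a candidate, then compare against the fp32 naive
// baseline and reject tactics whose max element-wise diff looks suspicious.
bool TacticLooksCorrect(ConvAllParams params, OpType op_type) {
  constexpr float kMaxAllowedDiff = 1e-2f;  // assumed tolerance
  float max_diff = conv2d_diff_gpu(params, op_type);
  return max_diff < kMaxAllowedDiff;
}
```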
+
+int ProfileToGetBestConfig(
+    const std::vector<std::function<cutlass::Status(ConvAllParams)>> &all_func,
+    ConvAllParams params,
+    OpType op_type) {
+  constexpr int WARMUP = 10;
+  constexpr int REPEAT = 100;
+  float min_time = 100000.f;
+  int min_time_index = -1;
+  for (size_t i = 0; i < all_func.size(); i++) {
+    cutlass::Status status;
+    auto func = all_func[i];
+    // When a func produces a large diff, we set it to nullptr.
+    if (!func) continue;
+
+    for (int ii = 0; ii < WARMUP; ii++) {
+      status = func(params);
+    }
+
+    cudaEvent_t beg, end;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&beg));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&end));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(beg));
+    for (int ii = 0; ii < REPEAT; ii++) {
+      status = func(params);
+    }
+
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(end));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(end));
+    float elapsed_time;
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventElapsedTime(&elapsed_time, beg, end));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(beg));
+    PADDLE_ENFORCE_GPU_SUCCESS(cudaEventDestroy(end));
+    if (elapsed_time < min_time && status == cutlass::Status::kSuccess) {
+      min_time = elapsed_time;
+      min_time_index = static_cast<int>(i);
+    }
+    // debug code
+    VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff "
+            << conv2d_diff_gpu(params, op_type) << " compared with baseline.";
+  }
+
+  if (min_time_index < 0) {
+    PADDLE_THROW(
+        phi::errors::NotFound("Can't find any cutlass config for this %s op.",
+                              OpType2String(op_type).c_str()));
+  }
+  return min_time_index;
+}
+
+}  // namespace cutlass_internal
+}  // namespace fusion
+}  // namespace phi
diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h
new file mode 100644
index 0000000000000000000000000000000000000000..a5d0f83651488ee718de1e07ba2ae96b998c6c52
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h
@@ -0,0 +1,59 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <functional>
+#include <vector>
+#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"
+
+#include "cutlass/cutlass.h"
+#include "cutlass/gemm/device/gemm.h"
+
+#include "cutlass/conv/device/implicit_gemm_convolution.h"
+
+#include "paddle/fluid/memory/allocation/allocator.h"
+#include "paddle/fluid/memory/malloc.h"
+#include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/core/enforce.h"
+
+namespace phi {
+namespace fusion {
+namespace cutlass_internal {
+#define CUTLASS_CHECK(status)                                                 \
+  if (status != cutlass::Status::kSuccess) {                                  \
+    VLOG(3)                                                                   \
+        << "Cutlass can not deal with this problem size, skip this kernel!"; \
+    return status;                                                            \
+  }
+
+typedef enum {
+  CONV2D_BIAS,
+  CONV2D_BIAS_RELU,
+  CONV2D_BIAS_ADD_RELU,
+  CONV2D_BIAS_SILU,
+  CONV2D_BIAS_LEAKY_RELU
+} OpType;
+
+// conv2d_diff_gpu calculates the diff between the cutlass output and the
+// baseline output; you can use it for debugging. The return value is the max
+// diff between cutlass and the baseline.
+float conv2d_diff_gpu(ConvAllParams params, OpType op_type);
+
+int ProfileToGetBestConfig(
+    const std::vector<std::function<cutlass::Status(ConvAllParams)>> &all_func,
+    ConvAllParams params,
+    OpType op_type);
+
+}  // namespace cutlass_internal
+}  // namespace fusion
+}  // namespace phi
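The profiling loop in ProfileToGetBestConfig is the standard CUDA event pattern: warm up, record, repeat, synchronize, read the elapsed time. A standalone version of just that harness, with explicit event cleanup (names are illustrative):

```cpp
#include <cuda_runtime.h>
#include <functional>

// Times `work` with CUDA events: warm up, then time `repeat` launches.
// Returns milliseconds for the whole repeat loop, mirroring the patch.
float TimeKernelMs(const std::function<void()> &work,
                   int warmup = 10, int repeat = 100) {
  for (int i = 0; i < warmup; ++i) work();
  cudaEvent_t beg, end;
  cudaEventCreate(&beg);
  cudaEventCreate(&end);
  cudaEventRecord(beg);
  for (int i = 0; i < repeat; ++i) work();
  cudaEventRecord(end);
  cudaEventSynchronize(end);
  float ms = 0.f;
  cudaEventElapsedTime(&ms, beg, end);
  cudaEventDestroy(beg);  // release the events once timed
  cudaEventDestroy(end);
  return ms;
}
```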
diff --git a/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu b/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu
new file mode 100644
index 0000000000000000000000000000000000000000..93c5581ce9db6f3b2d50c8b1872b07cc864124ba
--- /dev/null
+++ b/paddle/phi/kernels/fusion/cutlass/conv2d_fusion.cu
@@ -0,0 +1,141 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/kernel_registry.h"
+
+#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"
+
+namespace phi {
+namespace fusion {
+namespace cutlass_internal {
+template <typename T, typename Context>
+void Conv2dFusionKernel(const Context& ctx,
+                        const DenseTensor& x,
+                        const DenseTensor& filter,
+                        const DenseTensor& bias,
+                        const paddle::optional<DenseTensor>& residual,
+                        const std::vector<int>& strides,
+                        const std::vector<int>& paddings,
+                        const std::string& padding_algorithm,
+                        int groups,
+                        const std::vector<int>& dilations,
+                        const std::string& data_format,
+                        const std::string& activation,
+                        float fuse_alpha,
+                        DenseTensor* output) {
+  ctx.template Alloc<T>(output);
+  auto in_dims = x.dims();
+  auto filter_dims = filter.dims();
+  auto out_dims = output->dims();
+  CHECK_EQ(in_dims.size() == 4UL, true);
+  CHECK_EQ(filter_dims.size() == 4UL, true);
+  CHECK_EQ(strides.size() == 2UL, true);
+  CHECK_EQ(dilations.size() == 2UL, true);
+  CHECK_EQ(groups == 1, true);
+  CHECK_EQ(padding_algorithm == "EXPLICIT", true);
+  const int batch = in_dims[0];
+  const int ic = in_dims[3];
+  const int ih = in_dims[1];
+  const int iw = in_dims[2];
+  int pad_h0 = 0;
+  int pad_h1 = 0;
+  int pad_w0 = 0;
+  int pad_w1 = 0;
+  if (paddings.size() == 2UL) {
+    pad_h0 = paddings[0];
+    pad_h1 = paddings[0];
+    pad_w0 = paddings[1];
+    pad_w1 = paddings[1];
+  } else if (paddings.size() == 4UL) {
+    pad_h0 = paddings[0];
+    pad_h1 = paddings[1];
+    pad_w0 = paddings[2];
+    pad_w1 = paddings[3];
+  } else {
+    PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+        "Attr paddings in conv2d_fusion must have 2 or 4 elements, but "
+        "received %u elements.",
+        paddings.size()));
+  }
+
+  const int stride_h = strides[0];
+  const int stride_w = strides[1];
+  const int dilation_h = dilations[0];
+  const int dilation_w = dilations[1];
+  const int oc = filter_dims[0];
+  const int kh = filter_dims[1];
+  const int kw = filter_dims[2];
+
+  CHECK_EQ(out_dims.size() == 4UL, true);
+  const int oh = out_dims[1];
+  const int ow = out_dims[2];
+
+  ConvAllParams params = {reinterpret_cast<const half *>(x.data<T>()),
+                          reinterpret_cast<const half *>(filter.data<T>()),
+                          reinterpret_cast<const half *>(bias.data<T>()),
+                          nullptr,
+                          reinterpret_cast<half *>(output->data<T>()),
+                          batch,
+                          ic,
+                          ih,
+                          iw,
+                          kh,
+                          kw,
+                          oc,
+                          pad_h0,
+                          pad_h1,
+                          pad_w0,
+                          pad_w1,
+                          stride_h,
+                          stride_w,
+                          dilation_h,
+                          dilation_w,
+                          oh,
+                          ow,
+                          &ctx};
+
+  if (residual) {
+    if (activation == "relu") {
+      params.residual = reinterpret_cast<const half *>(residual->data<T>());
+      Conv2dBiasAddRelu(params);
+    } else {
+      PADDLE_THROW(paddle::platform::errors::InvalidArgument(
+          "Cutlass now only supports relu activation in a residual block."));
+    }
+  } else if (activation == "relu") {
+    Conv2dBiasRelu(params);
+  } else if (activation == "swish") {
+    Conv2dBiasSilu(params);
+  } else if (activation == "identity") {
+    Conv2dBias(params);
+  } else if (activation == "leaky_relu") {
+    params.alpha = fuse_alpha;
+    Conv2dBiasLeakyRelu(params);
+  } else {
+    PADDLE_THROW(phi::errors::InvalidArgument(
+        "Cutlass does not support this activation: %s.", activation.c_str()));
+  }
+  output->set_layout(DataLayout::NHWC);
+}
+}  // namespace cutlass_internal
+}  // namespace fusion
+}  // namespace phi
+
+PD_REGISTER_KERNEL(conv2d_fusion_cutlass,
+                   GPU,
+                   ALL_LAYOUT,
+                   phi::fusion::cutlass_internal::Conv2dFusionKernel,
+                   float,
+                   phi::dtype::float16) {}
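Conv2dFusionKernel takes oh/ow from the inferred output shape; for reference, the relation it relies on is the standard convolution output-size arithmetic, which the unit tests below also use when sizing the residual input. A small restatement (the helper name is illustrative):

```cpp
// Standard conv output-size arithmetic, matching the Python test's
// residual_shape computation: floor((i + p0 + p1 - d*(k-1) - 1) / s) + 1.
int ConvOutDim(int in, int pad0, int pad1, int k, int stride, int dilation) {
  return (in + pad0 + pad1 - dilation * (k - 1) - 1) / stride + 1;
}
```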
diff --git a/paddle/phi/ops/compat/conv2d_sig.cc b/paddle/phi/ops/compat/conv2d_sig.cc
index 22ff9b3e1a8347104bb4f7fdd4ec322c4b3dd6a6..6963d6a06d8203388a3de0fa3bbcd40cdc6a90bf 100644
--- a/paddle/phi/ops/compat/conv2d_sig.cc
+++ b/paddle/phi/ops/compat/conv2d_sig.cc
@@ -53,9 +53,24 @@ KernelSignature Conv2dDoubleGradOpArgumentMapping(
                          {"DInput", "DFilter", "DDOutput"});
 }
 
+KernelSignature Conv2dFusionArgumentMapping(const ArgumentMappingContext& ctx) {
+  return KernelSignature("conv2d_fusion_cutlass",
+                         {"Input", "Filter", "Bias", "ResidualData"},
+                         {"strides",
+                          "paddings",
+                          "padding_algorithm",
+                          "groups",
+                          "dilations",
+                          "data_format",
+                          "activation",
+                          "fuse_alpha"},
+                         {"Output"});
+}
 }  // namespace phi
 
 PD_REGISTER_ARG_MAPPING_FN(conv2d, phi::Conv2dOpArgumentMapping);
+PD_REGISTER_ARG_MAPPING_FN(conv2d_fusion_cutlass,
+                           phi::Conv2dFusionArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(conv2d_grad, phi::Conv2dGradOpArgumentMapping);
 PD_REGISTER_ARG_MAPPING_FN(conv2d_grad_grad,
                            phi::Conv2dDoubleGradOpArgumentMapping);
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
index 05d26dde6eddfafa703268cd8904a857487cd71d..cfc83bbcb52047bad575bd0a9911f274d68cadb2 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/ir/inference/CMakeLists.txt
@@ -94,6 +94,19 @@ if(WITH_MKLDNN)
   endforeach()
 endif()
 
+# Below are the cutlass unittests.
+file(
+  GLOB TEST_CUTLASS
+  RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
+  "test_cutlass_*.py")
+string(REPLACE ".py" "" TEST_CUTLASS "${TEST_CUTLASS}")
+list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES ${TEST_CUTLASS})
+if(WITH_CUTLASS)
+  foreach(target ${TEST_CUTLASS})
+    py_test_modules(${target} MODULES ${target})
+  endforeach()
+endif()
+
 if(WITH_MKLDNN
    AND TENSORRT_FOUND
    AND WITH_GPU)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
index b561822f1af92f652f7c8a9851b2d2eee34330df..99450cae46f516ef5af647b667b77789cabd899d 100755
--- a/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/auto_scan_test.py
@@ -74,6 +74,8 @@ class IgnoreReasons(enum.Enum):
     PASS_ACCURACY_ERROR = 2
     # Accuracy is abnormal after enabling mkldnn.
     MKLDNN_ACCURACY_ERROR = 3
+    # Accuracy is abnormal after enabling cutlass.
+    CUTLASS_ACCURACY_ERROR = 4
 
     # TODO(wilber): just for backward compatible
 
@@ -877,3 +879,96 @@ class TrtLayerAutoScanTest(AutoScanTest):
         note: str,
     ):
         self.ignore_cases.append((teller, reason, note))
+
+
+class CutlassAutoScanTest(AutoScanTest):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def run_test(self, quant=False, *args, **kwargs):
+        status = True
+
+        for prog_config in self.sample_program_configs(*args, **kwargs):
+            # If the program is invalid, skip the case.
+            if not self.is_program_valid(prog_config):
+                continue
+
+            model, params = create_fake_model(prog_config)
+            feed_data = {}
+            for name, tensor_config in prog_config.inputs.items():
+                feed_data[name] = {
+                    'data': tensor_config.data,
+                    'lod': tensor_config.lod,
+                }
+            results: List[Dict[str, np.ndarray]] = []
+
+            # baseline: gpu run without ir_optim
+            base_config = self.create_inference_config(
+                ir_optim=False, use_gpu=True
+            )
+            logging.info('RUN program_config: ' + str(prog_config))
+            results.append(
+                self.run_test_config(
+                    model, params, prog_config, base_config, feed_data
+                )
+            )
+            self.success_log('RUN_GPU_BASELINE done')
+
+            for pred_config, (atol, rtol) in self.sample_predictor_configs(
+                prog_config
+            ):
+                # skip info
+                ignore_flag = False
+                for ignore_info in self.ignore_cases:
+                    if ignore_info[0](prog_config, pred_config):
+                        ignore_flag = True
+                        if (
+                            ignore_info[1]
+                            == IgnoreReasons.CUTLASS_ACCURACY_ERROR
+                        ):
+                            self.ignore_log(
+                                "[CUTLASS_ACCURACY_ERROR] "
+                                + ignore_info[2]
+                                + ' '
+                                + ' vs '
+                                + self.inference_config_str(pred_config)
+                            )
+                        else:
+                            raise NotImplementedError
+                        break
+
+                if os.path.exists(self.cache_dir):
+                    shutil.rmtree(self.cache_dir)
+                if not os.path.exists(self.cache_dir):
+                    os.mkdir(self.cache_dir)
+
+                try:
+                    results.append(
+                        self.run_test_config(
+                            model, params, prog_config, pred_config, feed_data
+                        )
+                    )
+                    self.assert_tensors_near(
+                        atol, rtol, results[-1], results[0]
+                    )
+                except Exception as e:
+                    self.fail_log(
+                        self.inference_config_str(pred_config)
+                        + '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
+                    )
+                    if not ignore_flag:
+                        status = False
+                    continue
+                self.success_log(
+                    'RUN predictor_config '
+                    + self.inference_config_str(pred_config)
+                    + ' done'
+                )
+
+        self.assertTrue(status)
+
+    def inference_config_str(self, config) -> str:
+        dic = {}
+        enable_gpu = config.use_gpu()
+        dic['use_gpu'] = enable_gpu
+        return str(dic)
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py b/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..8adeff0f73ddf96ee78ff3d0631547e7259491c8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_cutlass_conv2d_fusion_op.py
@@ -0,0 +1,306 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from functools import partial
+from itertools import product
+
+import numpy as np
+from auto_scan_test import CutlassAutoScanTest
+from program_config import ProgramConfig, TensorConfig
+
+import paddle.inference as paddle_infer
+
+
+# cba pattern
+class TestCutlassConv2dFusionOp1(CutlassAutoScanTest):
+    def sample_program_configs(self, *args, **kwargs):
+        def generate_input1(input_shape):
+            return np.random.random(input_shape).astype(np.float32)
+
+        def generate_weight(weight_shape):
+            return np.random.random(weight_shape).astype(np.float32)
+
+        def generate_bias(bias_shape):
+            return np.random.random(bias_shape).astype(np.float32)
+
+        input_shape_options = [[1, 16, 112, 112], [1, 8, 64, 64]]
+        weight_shape_options = [[24, -1, 3, 3]]
+        strides_options = [[1, 1], [2, 2]]
+        paddings_options = [[1, 1], [1, 0, 1, 2]]
+        groups_options = [1]
+        padding_algorithm_options = ['EXPLICIT']
+        dilations_options = [[2, 2], [1, 1]]
+        data_format_options = ['NCHW']
+        act_options = ['relu', 'leaky_relu', 'swish']
+
+        configurations = [
+            input_shape_options,
+            weight_shape_options,
+            strides_options,
+            paddings_options,
+            groups_options,
+            padding_algorithm_options,
+            dilations_options,
+            data_format_options,
+            act_options,
+        ]
+
+        for (
+            input_shape,
+            weight_shape,
+            strides,
+            paddings,
+            groups,
+            padding_algorithm,
+            dilations,
+            data_format,
+            act,
+        ) in product(*configurations):
+
+            weight_shape[1] = input_shape[1]
+            attrs = [
+                {
+                    "strides": strides,
+                    "paddings": paddings,
+                    "groups": groups,
+                    "padding_algorithm": padding_algorithm,
+                    "dilations": dilations,
+                    "data_format": data_format,
+                },
+                {"axis": 1},
+            ]
+
+            ops_config = [
+                {
+                    "op_type": "conv2d",
+                    "op_inputs": {
+                        "Input": ["input_data"],
+                        "Filter": ["conv2d_weight"],
+                    },
+                    "op_outputs": {"Output": ["conv_output_data"]},
+                    "op_attrs": attrs[0],
+                },
+                {
+                    "op_type": "elementwise_add",
+                    "op_inputs": {
+                        "X": ["conv_output_data"],
+                        "Y": ["elementwise_weight"],
+                    },
+                    "op_outputs": {"Out": ["output_data0"]},
+                    "op_attrs": attrs[1],
+                },
+                {
+                    "op_type": act,
+                    "op_inputs": {"X": ["output_data0"]},
+                    "op_outputs": {"Out": ["output_data1"]},
+                    "op_attrs": {},
+                },
+            ]
+
+            ops = self.generate_op_config(ops_config)
+
+            program_config = ProgramConfig(
+                ops=ops,
+                weights={
+                    "conv2d_weight": TensorConfig(
+                        data_gen=partial(generate_weight, weight_shape)
+                    ),
+                    "elementwise_weight": TensorConfig(
+                        data_gen=partial(generate_bias, [weight_shape[0]])
+                    ),
+                },
+                inputs={
+                    "input_data": TensorConfig(
+                        data_gen=partial(generate_input1, input_shape)
+                    )
+                },
+                outputs=["output_data1"],
+            )
+
+            yield program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_gpu=True)
+        config.enable_use_gpu(256, 0, paddle_infer.PrecisionType.Half)
+        config.exp_enable_use_cutlass()
+        yield config, (1e-2, 1e-2)
+
+    def test(self, *args, **kwargs):
+        self.run_test(quant=False, *args, **kwargs)
+
+
+# cbaa pattern
+class TestCutlassConv2dFusionOp2(CutlassAutoScanTest):
+    def sample_program_configs(self, *args, **kwargs):
+        def generate_input(input_shape):
+            return (np.random.random(input_shape) * 2 - 1).astype(np.float32)
+
+        def generate_weight(weight_shape):
+            return (np.random.random(weight_shape) * 2 - 1).astype(np.float32)
+
+        def generate_bias(bias_shape):
+            return np.random.random(bias_shape).astype(np.float32)
+
+        input_shape_options = [[1, 16, 112, 112], [1, 24, 64, 64]]
+        weight_shape_options = [[24, -1, 3, 3]]
+        strides_options = [[2, 2], [1, 1]]
+        paddings_options = [[1, 1]]
+        groups_options = [1]
+        padding_algorithm_options = ['EXPLICIT']
+        dilations_options = [[1, 1]]
+        data_format_options = ['NCHW']
+        act_options = ['relu']
+
+        configurations = [
+            input_shape_options,
+            weight_shape_options,
+            strides_options,
+            paddings_options,
+            groups_options,
+            padding_algorithm_options,
+            dilations_options,
+            data_format_options,
+            act_options,
+        ]
+
+        for (
+            input_shape,
+            weight_shape,
+            strides,
+            paddings,
+            groups,
+            padding_algorithm,
+            dilations,
+            data_format,
+            act,
+        ) in product(*configurations):
+            weight_shape[1] = input_shape[1]
+            residual_shape = list(input_shape)
+            residual_shape[1] = weight_shape[0]
+
+            ih = input_shape[2]
+            iw = input_shape[3]
+            pad_h0 = 0
+            pad_h1 = 0
+            pad_w0 = 0
+            pad_w1 = 0
+            if len(paddings) == 2:
+                pad_h0 = paddings[0]
+                pad_h1 = paddings[0]
+                pad_w0 = paddings[1]
+                pad_w1 = paddings[1]
+            elif len(paddings) == 4:
+                pad_h0 = paddings[0]
+                pad_h1 = paddings[1]
+                pad_w0 = paddings[2]
+                pad_w1 = paddings[3]
+            dilation_h = dilations[0]
+            dilation_w = dilations[1]
+            kh = weight_shape[2]
+            kw = weight_shape[3]
+            stride_h = strides[0]
+            stride_w = strides[1]
+            residual_shape[2] = (
+                int((ih + pad_h0 + pad_h1 - dilation_h * (kh - 1) - 1) / stride_h)
+                + 1
+            )
+            residual_shape[3] = (
+                int((iw + pad_w0 + pad_w1 - dilation_w * (kw - 1) - 1) / stride_w)
+                + 1
+            )
+
+            attrs = [
+                {
+                    "strides": strides,
+                    "paddings": paddings,
+                    "groups": groups,
+                    "padding_algorithm": padding_algorithm,
+                    "dilations": dilations,
+                    "data_format": data_format,
+                },
+                {"axis": 1},
+            ]
+
+            ops_config = [
+                {
+                    "op_type": "conv2d",
+                    "op_inputs": {
+                        "Input": ["input_data"],
+                        "Filter": ["conv2d_weight"],
+                    },
+                    "op_outputs": {"Output": ["conv_output_data"]},
+                    "op_attrs": attrs[0],
+                },
+                {
+                    "op_type": "elementwise_add",
+                    "op_inputs": {
+                        "X": ["conv_output_data"],
+                        "Y": ["elementwise_weight"],
+                    },
+                    "op_outputs": {"Out": ["output_data0"]},
+                    "op_attrs": attrs[1],
+                },
+                {
+                    "op_type": "elementwise_add",
+                    "op_inputs": {
+                        "X": ["residual_data"],
+                        "Y": ["output_data0"],
+                    },
+                    "op_outputs": {"Out": ["output_data1"]},
+                    "op_attrs": {},
+                },
+                {
+                    "op_type": act,
+                    "op_inputs": {"X": ["output_data1"]},
+                    "op_outputs": {"Out": ["output_data2"]},
+                    "op_attrs": {},
+                },
+            ]
+
+            ops = self.generate_op_config(ops_config)
+
+            program_config = ProgramConfig(
+                ops=ops,
+                weights={
+                    "conv2d_weight": TensorConfig(
+                        data_gen=partial(generate_weight, weight_shape)
+                    ),
+                    "elementwise_weight": TensorConfig(
+                        data_gen=partial(generate_bias, [weight_shape[0]])
+                    ),
+                },
+                inputs={
+                    "input_data": TensorConfig(
+                        data_gen=partial(generate_input, input_shape)
+                    ),
+                    "residual_data": TensorConfig(
+                        data_gen=partial(generate_input, residual_shape)
+                    ),
+                },
+                outputs=["output_data2"],
+            )
+
+            yield program_config
+
+    def sample_predictor_configs(self, program_config):
+        config = self.create_inference_config(use_gpu=True)
+        config.enable_use_gpu(256, 0, paddle_infer.PrecisionType.Half)
+        config.exp_enable_use_cutlass()
+        yield config, (1e-2, 1e-2)
+
+    def test(self, *args, **kwargs):
+        self.run_test(quant=False, *args, **kwargs)
+
+
+if __name__ == "__main__":
+    unittest.main()