Unverified commit c123dd1e, authored by Z zhoutianzi666, committed by GitHub

[Paddle Inference] Implement conv2d_fusion NHWC format using cutlass (#47989)

* Implement conv2d_fusion NHWC format using CUTLASS
* Add unit testing for CUTLASS Conv in inference
* Add experimental API for CUTLASS.
Parent 5ac96468
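The experimental API added by this PR lives on AnalysisConfig as Exp_EnableUseCutlass(). Below is a minimal usage sketch, not taken from the PR itself: the model paths and memory pool size are placeholders, the precision argument follows the three-argument EnableUseGpu binding visible in the pybind change further down, and the include path is the usual Paddle Inference one.

#include "paddle_inference_api.h"

int main() {
  // Placeholder model files; the cutlass path targets fp16 conv in NHWC layout.
  paddle_infer::Config config("./model.pdmodel", "./model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/100,
                      /*device_id=*/0,
                      paddle_infer::PrecisionType::kHalf);
  config.Exp_EnableUseCutlass();  // experimental API added in this PR
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}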
......@@ -105,6 +105,7 @@ pass_library(delete_fill_constant_op_pass inference)
pass_library(constant_folding_pass inference)
pass_library(auto_mixed_precision_pass inference)
pass_library(conv2d_fusion_layout_transfer_pass inference)
pass_library(silu_fuse_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
......
......@@ -143,10 +143,16 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
static_cast<phi::DataType>(Get<int>("model_precision")) ==
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
bool cutlass_enable = false;
bool cutlass_enable = Get<bool>("use_cutlass");
#ifdef PADDLE_WITH_CUTLASS
cutlass_enable = true;
const auto &prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
int sm_version = prop.major * 10 + prop.minor;
// Currently the cutlass kernels are only implemented for SM75.
if (sm_version != 75) {
cutlass_enable = false;
}
#endif
if (!(is_fp16_precision && cutlass_enable)) return;
......@@ -184,10 +190,21 @@ void Conv2dFusionLayoutTransferPass::ApplyImpl(ir::Graph *graph) const {
auto filter_names = op_node->Op()->Input("Filter");
auto act_type = op_node->Op()->GetAttrIfExists<std::string>("activation");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
std::unordered_set<std::string> cutlass_act_set = {
// conv2d_fusion has two forms: conv + bias + act, and conv + bias +
// elementwise_add + act.
std::unordered_set<std::string> cutlass_cba_act_set = {
"relu", "swish", "identity", "leaky_relu"};
if (!cutlass_act_set.count(act_type)) {
return false;
std::unordered_set<std::string> cutlass_cbaa_act_set = {"relu"};
bool is_residual = op_node->Op()->Input("ResidualData").size() >= 1UL;
if (is_residual) {
if (!cutlass_cbaa_act_set.count(act_type)) {
return false;
}
} else {
if (!cutlass_cba_act_set.count(act_type)) {
return false;
}
}
// If the filter's channel count is not a multiple of 8, conv2d_fusion will not run in NHWC layout.
......
......@@ -167,14 +167,19 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
phi::DataType::FLOAT16 ||
Get<bool>("enable_gpu_mixed");
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
if (is_fp16_precision) {
bool cutlass_enable = Get<bool>("use_cutlass");
if (is_fp16_precision && cutlass_enable) {
#ifdef PADDLE_WITH_CUTLASS
// cutlass now support these activations
// cutlass_act_set.insert("swish");
// cutlass_act_set.insert("relu");
// cutlass_act_set.insert("identity");
// cutlass_act_set.insert("leaky_relu");
const auto& prop = platform::GetDeviceProperties(Get<int>("gpu_device_id"));
int sm_version = prop.major * 10 + prop.minor;
// Currently the cutlass kernels are only implemented for SM75.
if (sm_version == 75) {
// Cutlass currently supports these cba (conv + bias + act) activations.
cutlass_act_set.insert("swish");
cutlass_act_set.insert("relu");
cutlass_act_set.insert("identity");
cutlass_act_set.insert("leaky_relu");
}
all_act_set.insert(cutlass_act_set.begin(), cutlass_act_set.end());
#endif
}
......@@ -198,8 +203,8 @@ void ConvElementwiseAddActFusePass::ApplyImpl(ir::Graph* graph) const {
auto* filter_var = scope->FindLocalVar(conv_filter->Name());
auto* filter_tensor = filter_var->GetMutable<phi::DenseTensor>();
CHECK_EQ(filter_tensor->dims().size() == 4UL, true);
// when this conv2d_fusion problem size is not supported by cutlass and not
// supported by cuDNN, we should not apply this pass
// When this conv2d_fusion problem size is not supported by cutlass and not
// supported by cuDNN, we should not apply this pass.
int oc = filter_tensor->dims()[0];
int ic = filter_tensor->dims()[1];
bool cutlass_can_fuse = oc % CUTLASS_NHWC_ALIGNMENT == 0 &&
......
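The hunk above is truncated right at the cutlass_can_fuse condition. As a self-contained sketch of the kind of alignment check involved, assuming both the output and input channels must be multiples of the 8-element NHWC alignment (the helper name is hypothetical, not code from this PR):

// Hypothetical helper illustrating the NHWC alignment requirement.
constexpr int CUTLASS_NHWC_ALIGNMENT = 8;
static bool CutlassCanFuseSketch(int oc, int ic) {
  return oc % CUTLASS_NHWC_ALIGNMENT == 0 && ic % CUTLASS_NHWC_ALIGNMENT == 0;
}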
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/silu_fuse_pass.h"
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
namespace ir {
void SiluFusePass::ApplyImpl(ir::Graph* graph) const {
// This pass is used for cutlass because cutlass can fuse conv + bias + silu.
bool cutlass_enable = Get<bool>("use_cutlass");
if (!cutlass_enable) {
return;
}
const std::string pattern_name = "silu_fuse";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* sigmoid_in = gpd.mutable_pattern()->NewNode("sigmoid_in");
auto sigmoid_op =
gpd.mutable_pattern()->NewNode("sigmoid_op")->assert_is_op("sigmoid");
auto sigmoid_out = gpd.mutable_pattern()
->NewNode("sigmoid_out")
->assert_is_op_output("sigmoid")
->AsIntermediate();
auto elementwise_mul_op = gpd.mutable_pattern()
->NewNode("elementwise_mul_op")
->assert_is_op("elementwise_mul");
auto elementwise_mul_out = gpd.mutable_pattern()
->NewNode("elementwise_mul_out")
->assert_is_op_output("elementwise_mul")
->AsOutput();
sigmoid_op->LinksFrom({sigmoid_in}).LinksTo({sigmoid_out});
elementwise_mul_op->LinksFrom({sigmoid_in, sigmoid_out})
.LinksTo({elementwise_mul_out});
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
Node* sigmoid_in_node = subgraph.at(sigmoid_in);
Node* sigmoid_op_node = subgraph.at(sigmoid_op);
Node* elementwise_mul_op_node = subgraph.at(elementwise_mul_op);
Node* elementwise_mul_out_node = subgraph.at(elementwise_mul_out);
OpDesc new_desc;
new_desc.SetType("swish");
new_desc.SetAttr("beta", 1.f);
new_desc.SetInput("X", {sigmoid_in_node->Name()});
new_desc.SetOutput("Out", {elementwise_mul_out_node->Name()});
new_desc.Flush();
std::unordered_set<const Node*> del_node_set;
del_node_set.insert(sigmoid_op_node);
del_node_set.insert(elementwise_mul_op_node);
GraphSafeRemoveNodes(graph, del_node_set);
auto fused_node = graph->CreateOpNode(&new_desc);
IR_NODE_LINK_TO(sigmoid_in_node, fused_node);
IR_NODE_LINK_TO(fused_node, elementwise_mul_out_node);
};
gpd(graph, handler);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(silu_fuse_pass, paddle::framework::ir::SiluFusePass);
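For reference, the rewrite above is valid because silu(x) = x * sigmoid(x) is exactly swish(x) with beta = 1, which is why the pass can replace the sigmoid + elementwise_mul subgraph with a single swish op. A tiny standalone check of that identity (not part of the pass):

#include <cmath>
#include <cstdio>

int main() {
  float x = 0.75f;
  float silu = x * (1.0f / (1.0f + std::exp(-x)));       // x * sigmoid(x)
  float swish_beta1 = x / (1.0f + std::exp(-1.0f * x));  // swish(x, beta = 1)
  std::printf("%f %f\n", silu, swish_beta1);             // identical values
  return 0;
}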
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class SiluFusePass : public FusePassBase {
public:
virtual ~SiluFusePass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -202,6 +202,7 @@ struct Argument {
// Passed from config.
DECL_ARGUMENT_FIELD(use_gpu, UseGPU, bool);
DECL_ARGUMENT_FIELD(use_cutlass, UseCutlass, bool);
DECL_ARGUMENT_FIELD(use_fc_padding, UseFcPadding, bool);
DECL_ARGUMENT_FIELD(gpu_device_id, GPUDeviceId, int);
......
......@@ -52,6 +52,7 @@ void IRPassManager::CreatePasses(Argument *argument,
for (const std::string &pass_name : passes) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_name);
pass->Set("use_varseqlen", new bool(argument->tensorrt_use_varseqlen()));
pass->Set("use_cutlass", new bool(argument->use_cutlass()));
pass->Set("with_interleaved",
new bool(argument->tensorrt_with_interleaved()));
pass->Set("tensorrt_transformer_posid",
......@@ -80,6 +81,10 @@ void IRPassManager::CreatePasses(Argument *argument,
pass->Set("optim_shape_tensor",
new std::map<std::string, std::vector<int>>());
// This gpu_device_id is used by some fp16 precision passes, so it is moved
// up here.
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
// tuned trt dynamic_shape
pass->Set("trt_tuned_dynamic_shape",
new bool(argument->tensorrt_tuned_dynamic_shape()));
......@@ -198,7 +203,6 @@ void IRPassManager::CreatePasses(Argument *argument,
"model_opt_cache_dir",
new std::string(GetOrCreateModelOptCacheDir(model_opt_cache_dir)));
}
pass->Set("gpu_device_id", new int(argument->gpu_device_id()));
pass->Set("use_static_engine", new bool(use_static_engine));
pass->Set("model_from_memory", new bool(argument->model_from_memory()));
pass->Set("use_inspector", new bool(argument->tensorrt_use_inspector()));
......
......@@ -115,6 +115,17 @@ void AnalysisConfig::EnableUseGpu(uint64_t memory_pool_init_size_mb,
Update();
}
void AnalysisConfig::Exp_EnableUseCutlass() {
#if defined(PADDLE_WITH_CUTLASS)
use_cutlass_ = true;
#else
LOG(ERROR) << "Please compile with cutlass to EnableUseCutlass()";
use_cutlass_ = false;
#endif
Update();
}
void AnalysisConfig::SetExecStream(void *stream) {
PADDLE_ENFORCE_NOT_NULL(
stream,
......@@ -389,6 +400,7 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_fc_padding_);
// GPU related.
CP_MEMBER(use_gpu_);
CP_MEMBER(use_cutlass_);
CP_MEMBER(use_external_stream_);
CP_MEMBER(exec_stream_);
CP_MEMBER(use_cudnn_);
......@@ -1249,6 +1261,7 @@ std::string AnalysisConfig::Summary() {
// gpu info
os.InsertRow({"use_gpu", use_gpu_ ? "true" : "false"});
if (use_gpu_) {
os.InsertRow({"use_cutlass", use_cutlass_ ? "true" : "false"});
os.InsertRow({"gpu_device_id", std::to_string(gpu_device_id_)});
os.InsertRow({"enable_gpu_mixed", std::to_string(enable_gpu_mixed_)});
os.InsertRow({"memory_pool_init_size",
......
......@@ -1088,6 +1088,7 @@ void AnalysisPredictor::PrepareArgument() {
// Init std::unique_ptr argument_.
argument_.reset(new Argument);
argument_->SetUseGPU(config_.use_gpu());
argument_->SetUseCutlass(config_.use_cutlass_);
argument_->SetUseFcPadding(config_.use_fc_padding());
argument_->SetGPUDeviceId(config_.gpu_device_id());
argument_->SetEnableIrOptim(config_.enable_ir_optim_);
......
......@@ -395,6 +395,12 @@ struct PD_INFER_DECL AnalysisConfig {
///
bool use_gpu() const { return use_gpu_; }
///
/// \brief When running an fp16 model on an Nvidia GPU, you can also try
/// running your model with cutlass.
///
void Exp_EnableUseCutlass();
///
///
/// \brief A boolean state telling whether the XPU is turned on.
///
/// \return bool Whether the XPU is turned on.
......@@ -1047,6 +1053,7 @@ struct PD_INFER_DECL AnalysisConfig {
// GPU related.
bool use_gpu_{false};
bool use_cutlass_{false};
int gpu_device_id_{0};
uint64_t memory_pool_init_size_mb_{100}; // initial size is 100MB.
bool enable_gpu_mixed_{false};
......
......@@ -164,6 +164,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
const std::vector<std::string> kGpuLowerPrecisionPasses{
"identity_scale_op_clean_pass",
"simplify_with_basic_ops_pass",
"silu_fuse_pass",
"delete_quant_dequant_linear_op_pass",
"delete_weight_dequant_linear_op_pass",
"map_depthwise_conv_to_conv_pass",
......@@ -172,6 +173,7 @@ const std::vector<std::string> kGpuLowerPrecisionPasses{
"conv_elementwise_add_act_fuse_pass",
"conv_elementwise_add2_act_fuse_pass",
"conv_elementwise_add_fuse_pass",
"conv2d_fusion_layout_transfer_pass",
"multihead_matmul_fuse_pass_v2",
"fused_multi_transformer_encoder_pass",
"fused_multi_transformer_decoder_pass",
......@@ -216,6 +218,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
"delete_weight_dequant_linear_op_pass", //
"map_depthwise_conv_to_conv_pass", //
"constant_folding_pass", //
"silu_fuse_pass", //
"conv_bn_fuse_pass", //
"conv_eltwiseadd_bn_fuse_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
......@@ -250,7 +253,8 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
#endif //
"transpose_flatten_concat_fuse_pass", //
"constant_folding_pass", //
"auto_mixed_precision_pass", //
"conv2d_fusion_layout_transfer_pass", //
"auto_mixed_precision_pass"
});
use_gpu_ = true;
......
......@@ -330,3 +330,13 @@ REGISTER_OPERATOR(
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
// This op is used by cutlass; conv2d_fusion_cutlass is an intermediate op
// produced by conv2d_fusion_layout_transfer_pass.
REGISTER_OPERATOR(
conv2d_fusion_cutlass,
ops::Conv2DFusionOp,
ops::Conv2DFusionOpMaker,
ops::ConvOpInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
......@@ -646,6 +646,7 @@ void BindAnalysisConfig(py::module *m) {
py::arg("memory_pool_init_size_mb"),
py::arg("device_id") = 0,
py::arg("precision_mode") = AnalysisConfig::Precision::kFloat32)
.def("exp_enable_use_cutlass", &AnalysisConfig::Exp_EnableUseCutlass)
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
.def("set_exec_stream",
[](AnalysisConfig &self, phi::CUDAStream &stream) {
......
......@@ -106,8 +106,7 @@ file(
"fusion/gpu/*.cu")
if(WITH_CUTLASS)
file(GLOB cutlass_cu "fusion/cutlass/default_moe_fc_traits.h"
"fusion/cutlass/linear_combination_ft_gelu.h" "fusion/cutlass/moe*")
file(GLOB cutlass_cu "fusion/cutlass/conv2d/*.cu" "fusion/cutlass/*.cu")
list(APPEND kernel_cu ${cutlass_cu})
endif()
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombination<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(
ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(
ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(
ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(
ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(
ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(
ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_all_func = {
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias;
std::mutex conv2d_bias_mutex;
void Conv2dBias(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias.count(problem_size)) {
conv2d_bias_all_func[map_problem_conv2d_bias.at(problem_size)](params);
return;
}
int best_config_index =
ProfileToGetBestConfig(conv2d_bias_all_func, params, CONV2D_BIAS);
std::lock_guard<std::mutex> guard(conv2d_bias_mutex);
map_problem_conv2d_bias[problem_size] = best_config_index;
conv2d_bias_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
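ProfileToGetBestConfig (declared in conv2d_util.h) is what fills the problem-size cache above; its body is not part of this hunk. A minimal sketch of what such a helper could look like, under the assumption that it simply times every enabled candidate with CUDA events on the kernel's stream and returns the index of the fastest one (the helper name and iteration count are illustrative only):

#include <cuda_runtime.h>
#include <functional>
#include <limits>
#include <vector>
// Assumes ConvAllParams and cutlass::Status from conv2d_util.h / CUTLASS.

static int ProfileCandidatesSketch(
    const std::vector<std::function<cutlass::Status(ConvAllParams)>> &funcs,
    ConvAllParams params) {
  cudaStream_t stream = params.ctx->stream();
  int best = 0;
  float best_ms = std::numeric_limits<float>::max();
  for (size_t i = 0; i < funcs.size(); ++i) {
    if (!funcs[i]) continue;  // disabled configs are stored as nullptr
    funcs[i](params);         // warm up once
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);
    cudaEventRecord(start, stream);
    for (int it = 0; it < 10; ++it) funcs[i](params);
    cudaEventRecord(stop, stream);
    cudaEventSynchronize(stop);
    float ms = 0.f;
    cudaEventElapsedTime(&ms, start, stop);
    cudaEventDestroy(start);
    cudaEventDestroy(stop);
    if (ms < best_ms) {
      best_ms = ms;
      best = static_cast<int>(i);
    }
  }
  return best;
}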
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h"
#include "cutlass/epilogue/thread/linear_combination_residual_block.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasAddReluImpl(ConvAllParams params) {
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock<
cutlass::half_t,
float,
float,
cutlass::half_t,
Alignment,
cutlass::epilogue::thread::Identity,
cutlass::plus,
cutlass::epilogue::thread::ReLu>;
using Conv2dFpropKernel =
typename cutlass::conv::kernel::DefaultConv2dFpropWithBroadcast<
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
cutlass::half_t,
cutlass::layout::TensorNHWC,
float,
cutlass::arch::OpClassTensorOp,
cutlass::arch::Sm75,
TShape,
WShape,
cutlass::gemm::GemmShape<16, 8, 8>,
EpilogueOp,
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>,
2,
cutlass::arch::OpMultiplyAdd,
cutlass::conv::IteratorAlgorithm::kOptimized,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
const half *residual = params.residual;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Conv2dProblemSize problem_size(
{batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
cutlass::conv::Mode::kCrossCorrelation,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)input, {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)weight, {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)residual, {oc, oc * ow, oc * ow * oh}},
{(cutlass::half_t *)output, {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f},
cutlass::conv::SplitKMode::kSerial,
(cutlass::half_t *)(bias),
nullptr,
0,
oc};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
// config 9
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 10
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 11
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 12
template cutlass::Status
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_add_relu_all_func = {
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<128, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 64, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasAddReluImpl<cutlass::gemm::GemmShape<256, 128, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_add_relu;
std::mutex conv2d_bias_add_relu_mutex;
void Conv2dBiasAddRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_add_relu.count(problem_size)) {
conv2d_bias_add_relu_all_func[map_problem_conv2d_bias_add_relu.at(
problem_size)](params);
return;
}
std::lock_guard<std::mutex> guard(conv2d_bias_add_relu_mutex);
// Config 6 produces a large numerical diff, so it is disabled here.
conv2d_bias_add_relu_all_func[6] = nullptr;
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_add_relu_all_func, params, CONV2D_BIAS_ADD_RELU);
map_problem_conv2d_bias_add_relu[problem_size] = best_config_index;
conv2d_bias_add_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
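For reference, the epilogue configured above (LinearCombinationResidualBlock with Identity as the activation, plus as the binary op, and ReLu as the final unary op) computes, per output element, relu(conv_acc + bias + residual) with alpha = beta = 1. A plain scalar reference of that semantics (an illustration, not CUTLASS code):

// Scalar reference of the fused conv + bias + residual-add + relu epilogue.
static inline float BiasAddReluEpilogueRef(float conv_acc,
                                           float bias,
                                           float residual) {
  float v = conv_acc + bias + residual;  // alpha = beta = 1
  return v > 0.f ? v : 0.f;              // relu
}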
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_leaky_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasLeakyReluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp = cutlass::epilogue::thread::LinearCombinationLeakyRelu<
ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
float alpha = params.alpha;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f, alpha}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasLeakyReluImpl<
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_leaky_relu_all_func = {
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasLeakyReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_leaky_relu;
std::mutex conv2d_bias_leaky_relu_mutex;
void Conv2dBiasLeakyRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_leaky_relu.count(problem_size)) {
conv2d_bias_leaky_relu_all_func[map_problem_conv2d_bias_leaky_relu.at(
problem_size)](params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_leaky_relu_all_func, params, CONV2D_BIAS_LEAKY_RELU);
std::lock_guard<std::mutex> guard(conv2d_bias_leaky_relu_mutex);
map_problem_conv2d_bias_leaky_relu[problem_size] = best_config_index;
conv2d_bias_leaky_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasReluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationRelu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_relu_all_func = {
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_relu;
std::mutex conv2d_bias_relu_mutex;
void Conv2dBiasRelu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_relu.count(problem_size)) {
conv2d_bias_relu_all_func[map_problem_conv2d_bias_relu.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_relu_all_func, params, CONV2D_BIAS_RELU);
std::lock_guard<std::mutex> guard(conv2d_bias_relu_mutex);
map_problem_conv2d_bias_relu[problem_size] = best_config_index;
conv2d_bias_relu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_bias_relu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 1>
cutlass::Status Conv2dBiasReluFewChannelsImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kFewChannels;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationRelu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status Conv2dBiasReluFewChannelsImpl<
cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_relu_few_channels_all_func = {
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasReluFewChannelsImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_relu_few_channels;
void Conv2dBiasReluFewChannels(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_relu_few_channels.count(problem_size)) {
conv2d_bias_relu_few_channels_all_func
[map_problem_conv2d_bias_relu_few_channels.at(problem_size)](params);
return;
}
//
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <mutex>
#include "cutlass/conv/kernel/default_conv2d_fprop.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
template <typename TShape, typename WShape, int Alignment = 8>
cutlass::Status Conv2dBiasSiluImpl(ConvAllParams params) {
using ElementAccumulator = float;
using ElementComputeEpilogue = float;
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = cutlass::half_t;
using LayoutInputA = cutlass::layout::TensorNHWC;
using LayoutInputB = cutlass::layout::TensorNHWC;
using LayoutOutput = cutlass::layout::TensorNHWC;
using MMAOp = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm75;
using ThreadblockShape = TShape;
using WarpShape = WShape;
using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;
using SwizzleThreadBlock =
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<4>;
constexpr int NumStages = 2;
static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm =
cutlass::conv::IteratorAlgorithm::kOptimized;
using EpilogueOp =
cutlass::epilogue::thread::LinearCombinationSilu<ElementOutput,
Alignment,
float,
ElementComputeEpilogue>;
using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop<
ElementInputA,
LayoutInputA,
ElementInputB,
LayoutInputB,
ElementOutput,
LayoutOutput,
ElementAccumulator,
MMAOp,
SmArch,
ThreadblockShape,
WarpShape,
InstructionShape,
EpilogueOp,
SwizzleThreadBlock,
NumStages,
cutlass::arch::OpMultiplyAdd,
IteratorAlgorithm,
cutlass::conv::StrideSupport::kStrided,
Alignment,
Alignment>::Kernel;
using ImplicitGemm =
cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int oh = params.oh;
int ow = params.ow;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
cutlass::conv::Conv2dProblemSize problem_size({batch, ih, iw, ic},
{oc, kh, kw, ic},
{pad_h0, 0, pad_w0, 0},
{stride_h, stride_w},
{dilation_h, dilation_w},
{batch, oh, ow, oc},
mode,
1);
typename ImplicitGemm::Arguments arguments{
problem_size,
{(cutlass::half_t *)(input), {ic, ic * iw, ic * iw * ih}},
{(cutlass::half_t *)(weight), {ic, ic * kw, ic * kw * kh}},
{(cutlass::half_t *)(bias), {0, 0, 0}},
{(cutlass::half_t *)(output), {oc, oc * ow, oc * ow * oh}},
{1.f, 1.f}};
ImplicitGemm implicit_gemm_op;
size_t bytes = implicit_gemm_op.get_workspace_size(arguments);
auto ctx = params.ctx;
auto stream = ctx->stream();
paddle::memory::allocation::AllocationPtr tmp_gpu_ptrs_data =
paddle::memory::Alloc(
ctx->GetPlace(),
bytes,
phi::Stream(reinterpret_cast<phi::StreamId>(stream)));
void *workspace = tmp_gpu_ptrs_data->ptr();
cutlass::Status status = implicit_gemm_op.can_implement(arguments);
CUTLASS_CHECK(status);
status = implicit_gemm_op.initialize(arguments, workspace);
CUTLASS_CHECK(status);
status = implicit_gemm_op(stream);
CUTLASS_CHECK(status);
return status;
}
// config 0
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 1
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 2
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 3
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>(ConvAllParams);
// config 4
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>(ConvAllParams);
// config 5
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>(ConvAllParams);
// config 6
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 7
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>(ConvAllParams);
// config 8
template cutlass::Status
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>(ConvAllParams);
std::vector<std::function<cutlass::Status(ConvAllParams)>>
conv2d_bias_silu_all_func = {
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 32, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 64>,
cutlass::gemm::GemmShape<32, 32, 64>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 64, 32>,
cutlass::gemm::GemmShape<32, 32, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 32>,
cutlass::gemm::GemmShape<32, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 128, 64>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<64, 256, 32>,
cutlass::gemm::GemmShape<64, 64, 32>>,
Conv2dBiasSiluImpl<cutlass::gemm::GemmShape<128, 64, 32>,
cutlass::gemm::GemmShape<64, 32, 32>>};
std::map<std::vector<int>, int> map_problem_conv2d_bias_silu;
std::mutex conv2d_bias_silu_mutex;
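// Conv2dBiasSilu autotunes at runtime: the first time a problem size is seen,
// every kernel in conv2d_bias_silu_all_func is profiled and the index of the
// fastest one is cached in map_problem_conv2d_bias_silu; later calls with the
// same problem size reuse the cached config directly.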
void Conv2dBiasSilu(ConvAllParams params) {
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h0 = params.pad_h0;
int pad_w0 = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
std::vector<int> problem_size = {
batch, ic, ih, iw, kh, kw, oc, pad_h0, pad_w0, stride_h, stride_w};
if (map_problem_conv2d_bias_silu.count(problem_size)) {
conv2d_bias_silu_all_func[map_problem_conv2d_bias_silu.at(problem_size)](
params);
return;
}
int best_config_index = ProfileToGetBestConfig(
conv2d_bias_silu_all_func, params, CONV2D_BIAS_SILU);
std::lock_guard<std::mutex> guard(conv2d_bias_silu_mutex);
map_problem_conv2d_bias_silu[problem_size] = best_config_index;
conv2d_bias_silu_all_func[best_config_index](params);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_fp16.h>
#include <glog/logging.h>
#include <map>
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
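// ConvAllParams bundles every runtime argument of one fused conv2d problem;
// the input/weight/bias/residual/output pointers are device pointers to
// fp16 data in NHWC layout.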
typedef struct {
const half *input;
const half *weight;
const half *bias;
const half *residual;
half *output;
int batch;
int ic;
int ih;
int iw;
int kh;
int kw;
int oc;
int pad_h0;
int pad_h1;
int pad_w0;
int pad_w1;
int stride_h;
int stride_w;
int dilation_h;
int dilation_w;
int oh;
int ow;
const phi::GPUContext *ctx;
float alpha; // only used by leaky_relu
} ConvAllParams;
// The functions below are implemented with CUTLASS; they are called by phi.
void Conv2dBiasAddRelu(ConvAllParams params);
void Conv2dBiasRelu(ConvAllParams params);
void Conv2dBiasLeakyRelu(ConvAllParams params);
void Conv2dBiasSilu(ConvAllParams params);
void Conv2dBias(ConvAllParams params);
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_util.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
struct logical_coord {
int n;
int c;
int h;
int w;
};
float diff(const half *c, const float *c_baseline, int n) {
float max_diff = -1.;
for (int i = 0; i < n; i++) {
float c_value = __half2float(c[i]);
if (std::abs(c_baseline[i] - c_value) > max_diff) {
max_diff = std::abs(c_baseline[i] - c_value);
}
}
return max_diff;
}
__device__ int gpu_nhwc(struct logical_coord shape,
struct logical_coord index) {
return index.n * shape.h * shape.w * shape.c + index.h * shape.w * shape.c +
index.w * shape.c + index.c;
}
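// naive_conv2d_kernel is a reference implementation used only for accuracy
// checking: it treats the NHWC convolution as an implicit GEMM with
// M = batch * oh * ow, N = oc, K = ic * kh * kw, accumulates in fp32, and
// applies the same bias/activation (and optional residual add) as the fused op.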
__global__ void naive_conv2d_kernel(const half *input,
const half *weight,
const half *bias,
float *output,
int batch,
int ic,
int ih,
int iw,
int kh,
int kw,
int oc,
int pad_h,
int pad_w,
int stride_h,
int stride_w,
int dilation_h,
int dilation_w,
int oh,
int ow,
const half *residual,
float alpha, // for leaky_relu
OpType op_type) {
int M = batch * oh * ow;
int N = oc;
int K = ic * kh * kw;
int m_i = threadIdx.x + blockIdx.x * blockDim.x;
int n_i = threadIdx.y + blockIdx.y * blockDim.y;
if (m_i >= M || n_i >= N) return;
int batch_i = m_i / (oh * ow);
int oh_i = (m_i % (oh * ow)) / ow;
int ow_i = (m_i % (oh * ow)) % ow;
int oc_i = n_i;
struct logical_coord weight_shape = {oc, ic, kh, kw};
struct logical_coord input_shape = {batch, ic, ih, iw};
int out_offset = m_i * N + n_i;
float *out_ptr = output + out_offset;
float sum = 0.f;
for (int k_i = 0; k_i < K; k_i++) {
int ic_i = k_i / (kh * kw);
int kh_i = (k_i % (kh * kw)) / kw;
int kw_i = (k_i % (kh * kw)) % kw;
struct logical_coord weight_index = {oc_i, ic_i, kh_i, kw_i};
int ih_i = oh_i * stride_h - pad_h + kh_i * dilation_h;
int iw_i = ow_i * stride_w - pad_w + kw_i * dilation_w;
if (ih_i < 0 || ih_i >= ih) continue;
if (iw_i < 0 || iw_i >= iw) continue;
struct logical_coord input_index = {batch_i, ic_i, ih_i, iw_i};
const half *weight_ptr = weight + gpu_nhwc(weight_shape, weight_index);
const half *in_ptr = input + gpu_nhwc(input_shape, input_index);
sum += __half2float(*in_ptr) * __half2float(*weight_ptr);
}
sum += __half2float(*(bias + oc_i));
float x = sum;
switch (op_type) {
case CONV2D_BIAS:
*out_ptr = x;
break;
case CONV2D_BIAS_RELU:
*out_ptr = x > 0 ? x : 0;
break;
case CONV2D_BIAS_SILU:
*out_ptr = x * (1.f / (1 + exp(-x)));
break;
case CONV2D_BIAS_ADD_RELU:
x += __half2float(*(residual + out_offset));
*out_ptr = x > 0 ? x : 0;
break;
case CONV2D_BIAS_LEAKY_RELU:
*out_ptr = x > 0 ? x : (x * alpha);
break;
default:
break;
}
}
float conv2d_diff_gpu(ConvAllParams params, OpType op_type) {
const half *input = params.input;
const half *weight = params.weight;
const half *bias = params.bias;
half *output = params.output;
int batch = params.batch;
int ic = params.ic;
int ih = params.ih;
int iw = params.iw;
int kh = params.kh;
int kw = params.kw;
int oc = params.oc;
int pad_h = params.pad_h0;
int pad_w = params.pad_w0;
int stride_h = params.stride_h;
int stride_w = params.stride_w;
int dilation_h = params.dilation_h;
int dilation_w = params.dilation_w;
const half *residual = params.residual;
int oh = params.oh;
int ow = params.ow;
int M = batch * oh * ow;
int N = oc;
constexpr int blockM = 16;
constexpr int blockN = 16;
uint3 grid = {(M + blockM - 1) / blockM, (N + blockN - 1) / blockN, 1};
uint3 block = {blockM, blockN, 1};
int output_size = batch * oc * oh * ow;
half *output_from_cutlass =
reinterpret_cast<half *>(malloc(sizeof(half) * output_size));
cudaMemcpy(output_from_cutlass,
output,
output_size * sizeof(half),
cudaMemcpyDeviceToHost);
float *gpu_output;
cudaMalloc(&gpu_output, output_size * sizeof(float));
naive_conv2d_kernel<<<grid, block>>>(input,
weight,
bias,
gpu_output,
batch,
ic,
ih,
iw,
kh,
kw,
oc,
pad_h,
pad_w,
stride_h,
stride_w,
dilation_h,
dilation_w,
oh,
ow,
residual,
params.alpha,
op_type);
float *output_from_gpu =
reinterpret_cast<float *>(malloc(sizeof(float) * output_size));
cudaMemcpy(output_from_gpu,
gpu_output,
output_size * sizeof(float),
cudaMemcpyDeviceToHost);
float max_diff = diff(output_from_cutlass, output_from_gpu, output_size);
free(output_from_cutlass);
free(output_from_gpu);
cudaFree(gpu_output);
return max_diff;
}
std::string OpType2String(OpType op_type) {
switch (op_type) {
case CONV2D_BIAS:
return "conv2d_bias";
break;
case CONV2D_BIAS_RELU:
return "conv2d_bias_relu";
break;
case CONV2D_BIAS_SILU:
return "conv2d_bias_add_silu";
break;
case CONV2D_BIAS_ADD_RELU:
return "conv2d_bias_add_relu";
break;
case CONV2D_BIAS_LEAKY_RELU:
return "conv2d_bias_leaky_relu";
default:
break;
}
return "unnamed_op";
}
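// ProfileToGetBestConfig runs every candidate kernel with WARMUP untimed
// iterations followed by REPEAT iterations timed with CUDA events, and
// returns the index of the fastest kernel that reports
// cutlass::Status::kSuccess.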
int ProfileToGetBestConfig(
const std::vector<std::function<cutlass::Status(ConvAllParams)>> &all_func,
ConvAllParams params,
OpType op_type) {
constexpr int WARMUP = 10;
constexpr int REPEAT = 100;
float min_time = 100000.f;
int min_time_index = -1;
for (int i = 0; i < all_func.size(); i++) {
cutlass::Status status;
auto func = all_func[i];
// When a candidate kernel produces a large diff against the baseline, its
// entry is set to nullptr and skipped here.
if (!func) continue;
for (int ii = 0; ii < WARMUP; ii++) {
status = func(params);
}
cudaEvent_t beg, end;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&beg));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventCreate(&end));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(beg));
for (int ii = 0; ii < REPEAT; ii++) {
status = func(params);
}
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventRecord(end));
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventSynchronize(end));
float elapsed_time;
PADDLE_ENFORCE_GPU_SUCCESS(cudaEventElapsedTime(&elapsed_time, beg, end));
if (elapsed_time < min_time && status == cutlass::Status::kSuccess) {
min_time = elapsed_time;
min_time_index = i;
}
// Debug info: report this tactic's max diff against the naive baseline.
VLOG(3) << OpType2String(op_type) << ": tactic " << i << " has max diff "
<< conv2d_diff_gpu(params, op_type) << " compared with baseline.";
}
if (min_time_index < 0) {
PADDLE_THROW(
phi::errors::NotFound("Can't find any cutlass config for this %s op.",
OpType2String(op_type).c_str()));
}
return min_time_index;
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cuda_fp16.h>
#include <vector>
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/conv/device/implicit_gemm_convolution.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/enforce.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
#define CUTLASS_CHECK(status) \
if (status != cutlass::Status::kSuccess) { \
VLOG(3) \
<< "Cutlass can not deal with this problem size, skip this kernel!"; \
return status; \
}
typedef enum {
CONV2D_BIAS,
CONV2D_BIAS_RELU,
CONV2D_BIAS_ADD_RELU,
CONV2D_BIAS_SILU,
CONV2D_BIAS_LEAKY_RELU
} OpType;
// conv2d_diff_gpu calculates the diff between the cutlass output and the
// baseline output; you can use it for debugging. The return value is the max
// diff between cutlass and the baseline.
float conv2d_diff_gpu(ConvAllParams params, OpType op_type);
int ProfileToGetBestConfig(
const std::vector<std::function<cutlass::Status(ConvAllParams)>>& all_func,
ConvAllParams params,
OpType op_type);
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fusion/cutlass/conv2d/conv2d_decl.h"
namespace phi {
namespace fusion {
namespace cutlass_internal {
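// Conv2dFusionKernel expects fp16 tensors in NHWC layout; it packs the
// shapes, paddings, strides and dilations into ConvAllParams and dispatches
// to the CUTLASS entry that matches the activation (and the optional
// residual input).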
template <typename T, typename Context>
void Conv2dFusionKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& filter,
const DenseTensor& bias,
const paddle::optional<DenseTensor>& residual,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const std::string& padding_algorithm,
int groups,
const std::vector<int>& dilations,
const std::string& data_format,
const std::string& activation,
float fuse_alpha,
DenseTensor* output) {
ctx.template Alloc<T>(output);
auto in_dims = x.dims();
auto filter_dims = filter.dims();
auto out_dims = output->dims();
CHECK_EQ(in_dims.size() == 4UL, true);
CHECK_EQ(filter_dims.size() == 4UL, true);
CHECK_EQ(strides.size() == 2UL, true);
CHECK_EQ(dilations.size() == 2UL, true);
CHECK_EQ(groups == 1, true);
CHECK_EQ(padding_algorithm == "EXPLICIT", true);
const int batch = in_dims[0];
const int ic = in_dims[3];
const int ih = in_dims[1];
const int iw = in_dims[2];
int pad_h0 = 0;
int pad_h1 = 0;
int pad_w0 = 0;
int pad_w1 = 0;
if (paddings.size() == 2UL) {
pad_h0 = paddings[0];
pad_h1 = paddings[0];
pad_w0 = paddings[1];
pad_w1 = paddings[1];
} else if (paddings.size() == 4UL) {
pad_h0 = paddings[0];
pad_h1 = paddings[1];
pad_w0 = paddings[2];
pad_w1 = paddings[3];
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Attr paddins in conv2d_fusion must have 2 or 4 elements, but now have "
"%u elements.",
paddings.size()));
}
const int stride_h = strides[0];
const int stride_w = strides[1];
const int dilation_h = dilations[0];
const int dilation_w = dilations[1];
const int oc = filter_dims[0];
const int kh = filter_dims[1];
const int kw = filter_dims[2];
CHECK_EQ(out_dims.size() == 4UL, true);
const int oh = out_dims[1];
const int ow = out_dims[2];
ConvAllParams params = {reinterpret_cast<const half*>(x.data<T>()),
reinterpret_cast<const half*>(filter.data<T>()),
reinterpret_cast<const half*>(bias.data<T>()),
nullptr,
reinterpret_cast<half*>(output->data<T>()),
batch,
ic,
ih,
iw,
kh,
kw,
oc,
pad_h0,
pad_h1,
pad_w0,
pad_w1,
stride_h,
stride_w,
dilation_h,
dilation_w,
oh,
ow,
&ctx};
if (residual) {
if (activation == "relu") {
params.residual = reinterpret_cast<const half*>(residual->data<T>());
Conv2dBiasAddRelu(params);
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Cutlass now only support relu activation in a residual block"));
}
} else if (activation == "relu") {
Conv2dBiasRelu(params);
} else if (activation == "swish") {
Conv2dBiasSilu(params);
} else if (activation == "identity") {
Conv2dBias(params);
} else if (activation == "leaky_relu") {
params.alpha = fuse_alpha;
Conv2dBiasLeakyRelu(params);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Cutlass does not support this activation: %s.", activation.c_str()));
}
output->set_layout(DataLayout::NHWC);
}
} // namespace cutlass_internal
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(conv2d_fusion_cutlass,
GPU,
ALL_LAYOUT,
phi::fusion::cutlass_internal::Conv2dFusionKernel,
float,
phi::dtype::float16) {}
......@@ -53,9 +53,24 @@ KernelSignature Conv2dDoubleGradOpArgumentMapping(
{"DInput", "DFilter", "DDOutput"});
}
KernelSignature Conv2dFusionArgumentMapping(const ArgumentMappingContext& ctx) {
return KernelSignature("conv2d_fusion_cutlass",
{"Input", "Filter", "Bias", "ResidualData"},
{"strides",
"paddings",
"padding_algorithm",
"groups",
"dilations",
"data_format",
"activation",
"fuse_alpha"},
{"Output"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(conv2d, phi::Conv2dOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_fusion_cutlass,
phi::Conv2dFusionArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_grad, phi::Conv2dGradOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(conv2d_grad_grad,
phi::Conv2dDoubleGradOpArgumentMapping);
......@@ -94,6 +94,19 @@ if(WITH_MKLDNN)
endforeach()
endif()
# Below are the cutlass unit tests.
file(
GLOB TEST_CUTLASS
RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}"
"test_cutlass_*.py")
string(REPLACE ".py" "" TEST_CUTLASS "${TEST_CUTLASS}")
list(REMOVE_ITEM TEST_INFERENCE_IR_PASSES ${TEST_CUTLASS})
if(WITH_CUTLASS)
foreach(target ${TEST_CUTLASS})
py_test_modules(${target} MODULES ${target})
endforeach()
endif()
if(WITH_MKLDNN
AND TENSORRT_FOUND
AND WITH_GPU)
......
......@@ -74,6 +74,8 @@ class IgnoreReasons(enum.Enum):
PASS_ACCURACY_ERROR = 2
# Accuracy is abnormal after enabling mkldnn.
MKLDNN_ACCURACY_ERROR = 3
# Accuracy is abnormal after enabling cutlass.
CUTLASS_ACCURACY_ERROR = 4
# TODO(wilber): just for backward compatible
......@@ -877,3 +879,96 @@ class TrtLayerAutoScanTest(AutoScanTest):
note: str,
):
self.ignore_cases.append((teller, reason, note))
class CutlassAutoScanTest(AutoScanTest):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
def run_test(self, quant=False, *args, **kwargs):
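# For every sampled program: run a GPU baseline without IR optimization,
# then run each cutlass-enabled predictor config and compare its outputs
# against the baseline within the given (atol, rtol) tolerances.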
status = True
for prog_config in self.sample_program_configs(*args, **kwargs):
# If the program is invalid, we should skip that case.
if not self.is_program_valid(prog_config):
continue
model, params = create_fake_model(prog_config)
feed_data = {}
for name, tensor_config in prog_config.inputs.items():
feed_data[name] = {
'data': tensor_config.data,
'lod': tensor_config.lod,
}
results: List[Dict[str, np.ndarray]] = []
# baseline: run on GPU with ir_optim disabled
base_config = self.create_inference_config(
ir_optim=False, use_gpu=True
)
logging.info('RUN program_config: ' + str(prog_config))
results.append(
self.run_test_config(
model, params, prog_config, base_config, feed_data
)
)
self.success_log('RUN_GPU_BASELINE done')
for pred_config, (atol, rtol) in self.sample_predictor_configs(
prog_config
):
# Check whether this case is in the ignore list.
ignore_flag = False
for ignore_info in self.ignore_cases:
if ignore_info[0](prog_config, pred_config):
ignore_flag = True
if (
ignore_info[1]
== IgnoreReasons.CUTLASS_ACCURACY_ERROR
):
self.ignore_log(
"[CUTLASS_ACCURACY_ERROR] "
+ ignore_info[2]
+ ' '
+ ' vs '
+ self.inference_config_str(pred_config)
)
else:
raise NotImplementedError
break
if os.path.exists(self.cache_dir):
shutil.rmtree(self.cache_dir)
if not os.path.exists(self.cache_dir):
os.mkdir(self.cache_dir)
try:
results.append(
self.run_test_config(
model, params, prog_config, pred_config, feed_data
)
)
self.assert_tensors_near(
atol, rtol, results[-1], results[0]
)
except Exception as e:
self.fail_log(
self.inference_config_str(pred_config)
+ '\033[1;31m \nERROR INFO: {}\033[0m'.format(str(e))
)
if not ignore_flag:
status = False
continue
self.success_log(
'RUN predictor_config '
+ self.inference_config_str(pred_config)
+ ' done'
)
self.assertTrue(status)
def inference_config_str(self, config) -> str:
dic = {}
enable_gpu = config.use_gpu()
dic['use_gpu'] = enable_gpu
return str(dic)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial
from itertools import product
import numpy as np
from auto_scan_test import CutlassAutoScanTest
from program_config import ProgramConfig, TensorConfig
import paddle.inference as paddle_infer
# cba pattern: conv2d + bias (elementwise_add) + activation
class TestCutlassConv2dFusionOp1(CutlassAutoScanTest):
def sample_program_configs(self, *args, **kwargs):
def generate_input1(input_shape):
return np.random.random(input_shape).astype(np.float32)
def generate_weight(weight_shape):
return np.random.random(weight_shape).astype(np.float32)
def generate_bias(bias_shape):
return np.random.random(bias_shape).astype(np.float32)
input_shape_options = [[1, 16, 112, 112], [1, 8, 64, 64]]
weight_shape_options = [[24, -1, 3, 3]]
strides_options = [[1, 1], [2, 2]]
paddings_options = [[1, 1], [1, 0, 1, 2]]
groups_options = [1]
padding_algorithm_options = ['EXPLICIT']
dilations_options = [[2, 2], [1, 1]]
data_format_options = ['NCHW']
act_options = ['relu', 'leaky_relu', 'swish']
configurations = [
input_shape_options,
weight_shape_options,
strides_options,
paddings_options,
groups_options,
padding_algorithm_options,
dilations_options,
data_format_options,
act_options,
]
for (
input_shape,
weight_shape,
strides,
paddings,
groups,
padding_algorithm,
dilations,
data_format,
act,
) in product(*configurations):
weight_shape[1] = input_shape[1]
attrs = [
{
"strides": strides,
"paddings": paddings,
"groups": groups,
"padding_algorithm": padding_algorithm,
"dilations": dilations,
"data_format": data_format,
},
{"axis": 1},
]
ops_config = [
{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["conv2d_weight"],
},
"op_outputs": {"Output": ["conv_output_data"]},
"op_attrs": attrs[0],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["conv_output_data"],
"Y": ["elementwise_weight"],
},
"op_outputs": {"Out": ["output_data0"]},
"op_attrs": attrs[1],
},
{
"op_type": act,
"op_inputs": {"X": ["output_data0"]},
"op_outputs": {"Out": ["output_data1"]},
"op_attrs": {},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"conv2d_weight": TensorConfig(
data_gen=partial(generate_weight, weight_shape)
),
"elementwise_weight": TensorConfig(
data_gen=partial(generate_bias, [weight_shape[0]])
),
},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input1, input_shape)
)
},
outputs=["output_data1"],
)
yield program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
config.enable_use_gpu(256, 0, paddle_infer.PrecisionType.Half)
config.exp_enable_use_cutlass()
yield config, (1e-2, 1e-2)
def test(self, *args, **kwargs):
self.run_test(quant=False, *args, **kwargs)
# cbaa pattern: conv2d + bias + residual elementwise_add + activation
class TestCutlassConv2dFusionOp2(CutlassAutoScanTest):
def sample_program_configs(self, *args, **kwargs):
def generate_input(input_shape):
return (np.random.random(input_shape) * 2 - 1).astype(np.float32)
def generate_weight(weight_shape):
return (np.random.random(weight_shape) * 2 - 1).astype(np.float32)
def generate_bias(bias_shape):
return np.random.random(bias_shape).astype(np.float32)
input_shape_options = [[1, 16, 112, 112], [1, 24, 64, 64]]
weight_shape_options = [[24, -1, 3, 3]]
strides_options = [[2, 2], [1, 1]]
paddings_options = [[1, 1]]
groups_options = [1]
padding_algorithm_options = ['EXPLICIT']
dilations_options = [[1, 1]]
data_format_options = ['NCHW']
act_options = ['relu']
configurations = [
input_shape_options,
weight_shape_options,
strides_options,
paddings_options,
groups_options,
padding_algorithm_options,
dilations_options,
data_format_options,
act_options,
]
for (
input_shape,
weight_shape,
strides,
paddings,
groups,
padding_algorithm,
dilations,
data_format,
act,
) in product(*configurations):
weight_shape[1] = input_shape[1]
residual_shape = list(input_shape)
residual_shape[1] = weight_shape[0]
ih = input_shape[2]
iw = input_shape[3]
pad_h0 = 0
pad_h1 = 0
pad_w0 = 0
pad_w1 = 0
if len(paddings) == 2:
pad_h0 = paddings[0]
pad_h1 = paddings[0]
pad_w0 = paddings[1]
pad_w1 = paddings[1]
elif len(paddings) == 4:
pad_h0 = paddings[0]
pad_h1 = paddings[1]
pad_w0 = paddings[2]
pad_w1 = paddings[3]
dilation_h = dilations[0]
dilation_w = dilations[1]
kh = weight_shape[2]
kw = weight_shape[3]
stride_h = strides[0]
stride_w = strides[1]
residual_shape[2] = (
int((ih + pad_h0 + pad_h1 - dilation_h * (kh - 1) - 1) / stride_h) + 1
)
residual_shape[3] = (
int((iw + pad_w0 + pad_w1 - dilation_w * (kw - 1) - 1) / stride_w) + 1
)
attrs = [
{
"strides": strides,
"paddings": paddings,
"groups": groups,
"padding_algorithm": padding_algorithm,
"dilations": dilations,
"data_format": data_format,
},
{"axis": 1},
]
ops_config = [
{
"op_type": "conv2d",
"op_inputs": {
"Input": ["input_data"],
"Filter": ["conv2d_weight"],
},
"op_outputs": {"Output": ["conv_output_data"]},
"op_attrs": attrs[0],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["conv_output_data"],
"Y": ["elementwise_weight"],
},
"op_outputs": {"Out": ["output_data0"]},
"op_attrs": attrs[1],
},
{
"op_type": "elementwise_add",
"op_inputs": {
"X": ["residual_data"],
"Y": ["output_data0"],
},
"op_outputs": {"Out": ["output_data1"]},
"op_attrs": {},
},
{
"op_type": act,
"op_inputs": {"X": ["output_data1"]},
"op_outputs": {"Out": ["output_data2"]},
"op_attrs": {},
},
]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(
ops=ops,
weights={
"conv2d_weight": TensorConfig(
data_gen=partial(generate_weight, weight_shape)
),
"elementwise_weight": TensorConfig(
data_gen=partial(generate_bias, [weight_shape[0]])
),
},
inputs={
"input_data": TensorConfig(
data_gen=partial(generate_input, input_shape)
),
"residual_data": TensorConfig(
data_gen=partial(generate_input, residual_shape)
),
},
outputs=["output_data2"],
)
yield program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_gpu=True)
config.enable_use_gpu(256, 0, paddle_infer.PrecisionType.Half)
config.exp_enable_use_cutlass()
yield config, (1e-2, 1e-2)
def test(self, *args, **kwargs):
self.run_test(quant=False, *args, **kwargs)
if __name__ == "__main__":
unittest.main()