提交 72403e89 编写于 作者: M Megvii Engine Team

perf(cuda): speedup chanwise conv with small feature map and large filter size

GitOrigin-RevId: e65b2ce85626730ed772bf49acfc45bc812ce166
上级 28d48f2f
......@@ -19,6 +19,7 @@ using namespace cuda;
ConvBiasForwardImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&chanwise);
non_cudnn_algos.push_back(&chanwise_small);
non_cudnn_algos.push_back(&depthwise_large_filter);
non_cudnn_algos.push_back(&inplace_matmul);
non_cudnn_algos.push_back(&matmul);
......@@ -34,6 +35,7 @@ ConvBiasForwardImpl::AlgoPack::AlgoPack() {
std::vector<AlgoBase*> conv_algos;
conv_algos.push_back(&chanwise);
conv_algos.push_back(&chanwise_small);
conv_algos.push_back(&depthwise_large_filter);
conv_algos.push_back(&chanwise8x8x32);
for (auto&& algo : cudnn_convs) {
conv_algos.push_back(&algo);
......
......@@ -22,7 +22,6 @@
#include "src/cuda/conv_bias/opr_impl.h"
#include "src/cuda/convolution_helper/parameter.cuh"
#include "src/cuda/cudnn_wrapper.h"
#include "src/cuda/handle.h"
#include <cuda.h>
#include <memory>
......@@ -57,6 +56,7 @@ public:
CUDA_CUDNN_CONVBIAS,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_DEPTHWISE_LARGE_FILTER,
CUDA_CHANWISE_INT8X8X32,
CUDA_CUDNN_CONV,
CUDA_INPLACE_MATMUL,
......@@ -257,6 +257,26 @@ private:
mutable std::string m_name;
};
class ConvBiasForwardImpl::AlgoDepthwiseLargeFilter final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override {
if (m_name.empty()) {
m_name = ConvBiasForward::algo_name<DirectParam>(
"DEPTHWISE_LARGE_FILTER", {});
}
return m_name.c_str();
}
MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER)
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
private:
mutable std::string m_name;
};
class ConvBiasForwardImpl::AlgoChanwise8x8x32 final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
......@@ -1084,6 +1104,7 @@ public:
AlgoFallbackNCHWQS8 fallback_nchw_qs8;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
AlgoDepthwiseLargeFilter depthwise_large_filter;
AlgoChanwise8x8x32 chanwise8x8x32;
AlgoInplaceMatmul inplace_matmul;
AlgoMatmul matmul;
......
/**
* \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "cuda.h"
#include "cuda_fp16.h"
// #include "src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cuh"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
#include "src/cuda/conv_bias/chanwise/kern_helper.cuh"
#include "src/cuda/conv_bias/chanwise/launch_config.cuh"
#include "src/cuda/fp16_help.cuh"
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
using namespace chanwise;
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl"
namespace megdnn {
namespace cuda {
namespace conv_bias {
namespace chanwise {
// =====================================fwd=====================================
#define check
template <>
void run_fwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(DepthwiseConv2dDirection::DIRECTION_FORWARD)
}
} // namespace chanwise
} // namespace conv_bias
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
......@@ -61,6 +61,10 @@ template <typename T>
void run_fwd_small(
T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);
template <typename T>
void run_fwd_depthwise_large_filter(
T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);
// implemented in fwd_8x8x32.cu
void run_fwd_8x8x32(
int32_t* dst, const int8_t* src, const int8_t* flt, const Param& param,
......
/**
* \file dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace conv_bias;
namespace {
inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
auto&& device_prop = cuda::current_device_prop();
int flt_smem_w = (param.flt_w + 3) / 4 * 4;
int flt_smem_h = 3;
int flt_reg_per_thread =
flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
int ow = param.out_w > 64 ? 64 : param.out_w;
int src_smem_w = ow + flt_smem_w - 1;
int src_smem_h = flt_smem_h + param.flt_h - 1;
int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread + src_reg_per_thread +
out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
return false;
}
return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
param.src_w == param.out_w;
}
} // anonymous namespace
bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
const SizeArgs& args) const {
if (!args.src_layout->is_contiguous() || !args.dst_layout->is_contiguous()) {
return false;
}
if (args.src_layout->dtype != args.filter_layout->dtype &&
args.src_layout->dtype != dtype::Float32()) {
return false;
}
if (args.z_layout->ndim > 0)
return false;
auto param = chanwise::Param::from_fwd_args(args);
auto&& fm = args.filter_meta;
return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
args.src_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && !fm.should_flip &&
is_available_depthwise_large_filter(param);
}
size_t ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::get_workspace_in_bytes(
const SizeArgs& args) const {
auto dst_layout = *args.dst_layout;
if (dst_layout.dtype.enumv() != args.bias_layout->dtype.enumv()) {
dst_layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype, dst_layout.dtype);
return dst_layout.span().dist_byte();
}
return 0;
}
void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) const {
WorkspaceBundle bundle{args.workspace.raw_ptr, {get_workspace_in_bytes(args)}};
TensorND conv_dst_tensor = *args.dst_tensor;
if (args.dst_layout->dtype.enumv() != args.bias_layout->dtype.enumv()) {
conv_dst_tensor = TensorND{bundle.get(0), conv_dst_tensor.layout};
conv_dst_tensor.layout.dtype = DType();
args.opr->check_or_deduce_dtype_fwd(
args.src_layout->dtype, args.filter_layout->dtype,
conv_dst_tensor.layout.dtype);
}
{
auto kparam = chanwise::Param::from_fwd_args(args);
auto stream = cuda_stream(args.handle);
switch (args.src_layout->dtype.enumv()) {
case DTypeEnum::Float32:
chanwise::run_fwd_depthwise_large_filter(
conv_dst_tensor.ptr<float>(), args.src_tensor->ptr<float>(),
args.filter_tensor->ptr<float>(), kparam, stream);
break;
default:
megdnn_assert_internal(0);
}
}
handle_bias_and_nonlinear(
args.handle, args.nonlinear_mode, &conv_dst_tensor, args.dst_tensor,
args.bias_tensor);
}
// vim: syntax=cpp.doxygen
......@@ -45,6 +45,7 @@ public:
class AlgoCUDNNConvBiasActivation;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoDepthwiseLargeFilter;
class AlgoChanwise8x8x32;
class AlgoCUDNNConv;
class AlgoFallbackNCHWQS8;
......
......@@ -695,6 +695,59 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_SMALL) {
}
}
TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
Checker<ConvBiasForward> checker(handle_cuda());
checker.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBias>(
ConvBiasForward::algo_name<ConvBias::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
auto run = [&checker](size_t n, size_t g, size_t h, size_t fh) {
param::ConvBias cur_param;
cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
cur_param.sparse = ConvBias::Param::Sparse::GROUP;
checker.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(3, dtype::Float32())
.set_dtype(4, dtype::Float32());
cur_param.pad_h = cur_param.pad_w = fh / 2;
cur_param.stride_h = cur_param.stride_w = 1;
checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 5);
run(4, 8, 64, 7);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
}
TEST_F(CUDA, CONV_BIAS_FORWARD_CHANWISE_8x8x32) {
require_compute_capability(6, 1);
Checker<ConvBiasForward> checker(handle_cuda());
......@@ -1474,6 +1527,69 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
run_bench(256, 512, 7, 7, 512, 3, 3, 1, 1, 1000);
run_bench(256, 512, 7, 7, 2048, 1, 1, 1, 1, 1000);
}
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
bencher.set_before_exec_callback(conv_bias::ConvBiasAlgoChecker<ConvBiasForward>(
ConvBiasForward::algo_name<ConvBiasForward::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
ConvBias::Param param;
param.format = ConvBias::Param::Format::NCHW;
using NonlineMode = ConvBias::Param::NonlineMode;
param.nonlineMode = NonlineMode::IDENTITY;
param.sparse = ConvBias::Param::Sparse::GROUP;
auto run_bench = [&](size_t batch, size_t g, size_t hi, size_t wi, size_t fh,
size_t fw, size_t sh, size_t sw, size_t nr_times) {
param.pad_h = fh / 2;
param.pad_w = fw / 2;
param.stride_h = sh;
param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
TensorShape inp{batch, g, hi, wi}, kern{g, 1, 1, fh, fw}, out{batch, g, ho, wo};
float bandwith = static_cast<float>(
inp.total_nr_elems() + kern.total_nr_elems() +
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: "
"%.2fms, "
"perf: %.2f Tops bandwidth: %.2fGB/s.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), time_in_ms, ops, bandwith * 4 / time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
run_bench(64, 384, 32, 32, 5, 5, 1, 1, 10);
run_bench(64, 384, 32, 32, 7, 7, 1, 1, 10);
run_bench(64, 384, 32, 32, 9, 9, 1, 1, 10);
run_bench(64, 384, 32, 32, 11, 11, 1, 1, 10);
run_bench(64, 384, 32, 32, 13, 13, 1, 1, 10);
run_bench(64, 384, 32, 32, 15, 15, 1, 1, 10);
run_bench(64, 384, 32, 32, 17, 17, 1, 1, 10);
run_bench(64, 384, 32, 32, 19, 19, 1, 1, 10);
run_bench(64, 384, 32, 32, 21, 21, 1, 1, 10);
run_bench(64, 384, 32, 32, 23, 23, 1, 1, 10);
run_bench(64, 384, 32, 32, 25, 25, 1, 1, 10);
run_bench(64, 384, 32, 32, 27, 27, 1, 1, 10);
run_bench(64, 384, 32, 32, 29, 29, 1, 1, 10);
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
#endif
#endif
......
......@@ -901,6 +901,43 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
run(32, 64, 64, 56, 56, 1, 1, 0);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
new OprProxy<ConvolutionBackwardData>{true}};
size_t RUNS = 10;
bench.set_proxy(proxy).set_times(RUNS);
auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
size_t SH, size_t PH) {
bench.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
param::Convolution param;
param.stride_h = param.stride_w = SH;
param.pad_h = param.pad_w = FH / 2;
param.sparse = param::Convolution::Sparse::GROUP;
bench.set_param(param);
bench.proxy()->target_execution_policy.algo.reset();
TensorLayout src{{N, g, IH, IW}, dtype::Float32()},
filter{{g, 1, 1, FH, FH}, dtype::Float32()};
TensorLayout dst;
{
auto&& opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(src, filter, dst);
}
auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS;
float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
filter.to_string().c_str(), dst.to_string().c_str());
printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32,
(flo / (time_ms_fp32 * 1e9)));
};
run(64, 384, 384, 32, 32, 31, 1, 15);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) {
CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
......@@ -1065,6 +1102,46 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) {
run(32, 512, 1024, 14, 14, 1, 2, 0);
run(32, 64, 64, 56, 56, 1, 1, 0);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardFilter> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
new OprProxy<ConvolutionBackwardFilter>{true}};
size_t RUNS = 10;
bench.set_proxy(proxy).set_times(RUNS);
bench.set_before_exec_callback(AlgoChecker<ConvolutionBackwardFilter>(
"CUDNN_CONVOLUTION_BWD_FILTER_ALGO_FFTv7.6.3"));
auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
size_t SH, size_t PH) {
bench.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
param::Convolution param;
param.stride_h = param.stride_w = SH;
param.pad_h = param.pad_w = FH / 2;
param.sparse = param::Convolution::Sparse::GROUP;
bench.set_param(param);
bench.proxy()->target_execution_policy.algo.reset();
TensorLayout src{{N, g, IH, IW}, dtype::Float32()},
filter{{g, 1, 1, FH, FH}, dtype::Float32()};
TensorLayout dst;
{
auto&& opr = handle_cuda()->create_operator<Convolution>();
opr->param() = param;
opr->deduce_layout(src, filter, dst);
}
auto time_ms_fp32 = bench.execl({src, dst, filter}) / RUNS;
float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
filter.to_string().c_str(), dst.to_string().c_str());
printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32,
(flo / (time_ms_fp32 * 1e9)));
};
run(64, 384, 384, 32, 32, 31, 1, 15);
}
#endif
#undef CUDNN_VERSION_STRING
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册