提交 7d2063e3 编写于 作者: M Megvii Engine Team

perf(cuda): speedup conv backward data with small feature map and large filter size

GitOrigin-RevId: 85592bca6b06239a3aa30621a386738553597023
上级 72403e89
......@@ -19,10 +19,12 @@ using namespace cuda;
ConvolutionBackwardDataImpl::AlgoPack::AlgoPack() {
non_cudnn_algos.push_back(&chanwise);
non_cudnn_algos.push_back(&chanwise_small);
non_cudnn_algos.push_back(&depthwise_large_filter);
non_cudnn_algos.push_back(&matmul);
all_algos.push_back(&chanwise); // prefer chanwise
all_algos.push_back(&chanwise_small); // prefer small chanwise
all_algos.push_back(&depthwise_large_filter);
fill_cudnn_algos();
for (auto&& i : cudnn) {
......
......@@ -37,6 +37,7 @@ public:
CUDA_MATMUL,
CUDA_CHANWISE,
CUDA_CHANWISE_SMALL,
CUDA_DEPTHWISE_LARGE_FILTER,
CUDA_BFLOAT16,
CUDA_GROUP_CONV_GENERAL,
CUDA_IMPLICIT_GEMM_NCHW4_DOTPROD_INT8,
......@@ -192,6 +193,20 @@ public:
}
};
class ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
size_t get_workspace_in_bytes(const SizeArgs& args) const override;
void exec(const ExecArgs& args) const override;
const char* name() const override { return "DEPTHWISE_LARGE_FILTER"; }
MEGDNN_DECL_ALGO_TYPE(CUDA_DEPTHWISE_LARGE_FILTER)
AlgoAttribute attribute() const override { return AlgoAttribute::REPRODUCIBLE; }
private:
mutable std::string m_name;
};
class ConvolutionBackwardDataImpl::AlgoBFloat16 final : public AlgoBase {
public:
bool is_available(const SizeArgs& args) const override;
......@@ -411,6 +426,7 @@ public:
AlgoMatmul matmul;
AlgoChanwise chanwise;
AlgoChanwiseSmall chanwise_small;
AlgoDepthwiseLargeFilter depthwise_large_filter;
AlgoBFloat16 bfloat16;
AlgoGroupConvGeneral group;
std::vector<AlgoInt8NCHW4DotProdImplicitGemm> int8_nchw4_dotprod;
......
/**
* \file dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
namespace {
inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
auto&& device_prop = cuda::current_device_prop();
int flt_smem_w = (param.flt_w + 3) / 4 * 4;
int flt_smem_h = 3;
int flt_reg_per_thread =
flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
int ow = param.out_w > 64 ? 64 : param.out_w;
int src_smem_w = ow + flt_smem_w - 1;
int src_smem_h = flt_smem_h + param.flt_h - 1;
int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread + src_reg_per_thread +
out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
return false;
}
return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
param.src_w == param.out_w;
}
} // anonymous namespace
bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
const SizeArgs& args) const {
if (!args.grad_layout->is_contiguous() || !args.diff_layout->is_contiguous()) {
return false;
}
if (args.diff_layout->dtype != args.filter_layout->dtype &&
args.diff_layout->dtype != dtype::Float32()) {
return false;
}
auto param = chanwise::Param::from_fwd_args(args.as_fwd_args());
auto&& fm = args.filter_meta;
return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && !fm.should_flip &&
is_available_depthwise_large_filter(param);
}
size_t ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::get_workspace_in_bytes(
const SizeArgs& args) const {
return 0;
}
void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
const ExecArgs& args) const {
auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
auto stream = cuda_stream(args.handle);
switch (args.diff_layout->dtype.enumv()) {
case DTypeEnum::Float32:
chanwise::run_bwd_depthwise_large_filter(
args.grad_tensor->ptr<float>(), args.diff_tensor->ptr<float>(),
args.filter_tensor->ptr<float>(), kparam, stream);
break;
default:
megdnn_assert_internal(0);
}
}
// vim: syntax=cpp.doxygen
/**
* \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./kern.cuh"
#include "./kern_helper.cuh"
#include "cuda.h"
#include "cuda_fp16.h"
#include "src/cuda/convolution/chanwise/launch_config.cuh"
#include "src/cuda/fp16_help.cuh"
using namespace megdnn;
using namespace cuda;
using namespace convolution;
using namespace chanwise;
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl"
namespace megdnn {
namespace cuda {
namespace convolution {
namespace chanwise {
// =====================================fwd=====================================
template <>
void run_bwd_depthwise_large_filter(
float* dst, const float* src, const float* flt, const Param& param,
cudaStream_t stream) {
INSTANCE(DepthwiseConv2dDirection::DIRECTION_BACKWARD)
}
} // namespace chanwise
} // namespace convolution
} // namespace cuda
} // namespace megdnn
// vim: syntax=cuda.doxygen
......@@ -63,6 +63,10 @@ void run_bwd_data(
T* src_grad, const T* dst_grad, const T* flt, const Param& param,
cudaStream_t stream);
template <typename T>
void run_bwd_depthwise_large_filter(
T* dst, const T* src, const T* flt, const Param& param, cudaStream_t stream);
template <typename T>
void run_bwd_filter(
T* filter_grad, const T* src, const T* dst_grad, const Param& param,
......
......@@ -97,6 +97,7 @@ public:
class AlgoMatmul;
class AlgoChanwise;
class AlgoChanwiseSmall;
class AlgoDepthwiseLargeFilter;
class AlgoGroupConvGeneral;
class AlgoBFloat16;
class AlgoInt8NCHW4DotProdImplicitGemm;
......
......@@ -724,6 +724,55 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DATA_1) {
TensorLayoutArray{filter, dst, src});
}
TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker<ConvolutionBackwardData> checker(handle_cuda());
checker.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
for (auto dtype : std::vector<DType>{dtype::Float32()}) {
auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
param::Convolution param;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = fh / 2;
param.mode = Convolution::Mode::CROSS_CORRELATION;
param.sparse = param::Convolution::Sparse::GROUP;
checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
checker.set_param(param).execs(
{{g, 1, 1, fh, fh}, {n, g, h, h}, {n, g, h, h}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 7);
run(4, 8, 64, 5);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
}
}
#if MEGDNN_WITH_BENCHMARK
TEST_F(CUDA, CONV_FWD_BENCHMARK) {
auto run = [&](size_t N, size_t OC, size_t IC, size_t IH, size_t IW, size_t SH = 1,
......@@ -901,24 +950,23 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
run(32, 64, 64, 56, 56, 1, 1, 0);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardData> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardData>> proxy{
new OprProxy<ConvolutionBackwardData>{true}};
size_t RUNS = 10;
bench.set_proxy(proxy).set_times(RUNS);
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
bencher.set_display(false);
bencher.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
auto run = [&](size_t N, size_t OC, size_t g, size_t IH, size_t IW, size_t FH,
size_t SH, size_t PH) {
bench.set_dtype(0, dtype::Float32())
size_t SH, size_t nr_times) {
bencher.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32());
param::Convolution param;
param.stride_h = param.stride_w = SH;
param.pad_h = param.pad_w = FH / 2;
param.sparse = param::Convolution::Sparse::GROUP;
bench.set_param(param);
bench.proxy()->target_execution_policy.algo.reset();
bencher.set_param(param);
bencher.set_times(nr_times);
TensorLayout src{{N, g, IH, IW}, dtype::Float32()},
filter{{g, 1, 1, FH, FH}, dtype::Float32()};
TensorLayout dst;
......@@ -927,15 +975,28 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
opr->param() = param;
opr->deduce_layout(src, filter, dst);
}
auto time_ms_fp32 = bench.execl({filter, dst, src}) / RUNS;
auto time_ms_fp32 = bencher.execl({filter, dst, src}) / nr_times;
float flo = 2.0 * N * g * dst[2] * dst[3] * FH * FH;
printf("inp=%s, kern=%s, dst=%s ", src.to_string().c_str(),
filter.to_string().c_str(), dst.to_string().c_str());
printf("time_fp32=%.2fms, flops=%.3fTFLOPS\n", time_ms_fp32,
(flo / (time_ms_fp32 * 1e9)));
};
run(64, 384, 384, 32, 32, 31, 1, 15);
run(64, 384, 384, 32, 32, 3, 1, 10);
run(64, 384, 384, 32, 32, 5, 1, 10);
run(64, 384, 384, 32, 32, 7, 1, 10);
run(64, 384, 384, 32, 32, 9, 1, 10);
run(64, 384, 384, 32, 32, 11, 1, 10);
run(64, 384, 384, 32, 32, 13, 1, 10);
run(64, 384, 384, 32, 32, 15, 1, 10);
run(64, 384, 384, 32, 32, 17, 1, 10);
run(64, 384, 384, 32, 32, 19, 1, 10);
run(64, 384, 384, 32, 32, 21, 1, 10);
run(64, 384, 384, 32, 32, 23, 1, 10);
run(64, 384, 384, 32, 32, 25, 1, 10);
run(64, 384, 384, 32, 32, 27, 1, 10);
run(64, 384, 384, 32, 32, 29, 1, 10);
run(64, 384, 384, 32, 32, 31, 1, 10);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_BF16) {
......@@ -1103,7 +1164,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_FILTER_BENCHMARK) {
run(32, 64, 64, 56, 56, 1, 1, 0);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_CHANWISE_SMALL_FEAT_LARGE_FILTER) {
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_FILTER_DEPTHWISE_LARGE_FILTER) {
CUBenchmarker<ConvolutionBackwardFilter> bench{handle_cuda()};
std::unique_ptr<OprProxy<ConvolutionBackwardFilter>> proxy{
new OprProxy<ConvolutionBackwardFilter>{true}};
......
......@@ -57,6 +57,7 @@
#cmakedefine01 MEGDNN_64_BIT
#cmakedefine01 MEGDNN_THREADS_512
#cmakedefine01 MEGDNN_ENABLE_MULTI_THREADS
#cmakedefine01 MEGDNN_WITH_BENCHMARK
// whether atlas is available
#ifndef MGB_ATLAS
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册