提交 6b8a69d5 编写于 作者: M Megvii Engine Team

feat(cuda): float16 depthwise large kernel conv compute fp32

GitOrigin-RevId: 3050d48f2691faeeda4fb054134041cc620b5a35
上级 bc385b53
......@@ -11,7 +11,6 @@
#include "cuda.h"
#include "cuda_fp16.h"
// #include "src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cuh"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
#include "src/cuda/conv_bias/chanwise/kern_helper.cuh"
#include "src/cuda/conv_bias/chanwise/launch_config.cuh"
......
......@@ -32,15 +32,14 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread + src_reg_per_thread +
out_reg_per_thread) ||
(flt_reg_per_thread * 2 +
src_reg_per_thread + out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
return false;
}
return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
param.src_w == param.out_w;
return true;
}
} // anonymous namespace
......
......@@ -68,7 +68,7 @@ public:
const TensorLayout& grad);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, grad_layout, filter_layout, filter_meta, diff_layout};
return {handle, diff_layout, filter_layout, filter_meta, grad_layout};
}
};
struct ExecArgs : public SizeArgs {
......
......@@ -31,15 +31,17 @@ inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread + src_reg_per_thread +
out_reg_per_thread) ||
(flt_reg_per_thread * 2 +
src_reg_per_thread + out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h + src_smem_w * src_smem_h)) {
flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
return false;
}
return param.stride_h == 1 && param.stride_w == 1 && param.src_h == param.out_h &&
param.src_w == param.out_w;
printf("param.src_w = %d, param.src_h = %d, param.out_w = %d, param.out_h = %d\n",
param.src_w, param.src_h, param.out_w, param.out_h);
return (param.stride_h == 1 && param.stride_w == 1) ||
(param.stride_h == 2 && param.stride_w == 2);
}
} // anonymous namespace
......
......@@ -45,6 +45,12 @@ fma2(const __half2 a, const __half2 b, const __half2 c) {
#endif
}
__device__ __forceinline__ float2
fma2(const __half2 a, const __half2 b, const float2 c) {
return {__half2float(a.x) * __half2float(b.x) + c.x,
__half2float(a.y) * __half2float(b.y) + c.y};
}
#endif // CUDA_VERSION >= 9000
} // namespace cuda
......
......@@ -701,8 +701,10 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward::algo_name<ConvBias::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {})
.c_str()));
for (auto dtype : std::vector<DType>{dtype::Float16()}) {
auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
auto run = [&checker, &dtype](
size_t n, size_t g, size_t h, size_t fh, size_t padding,
size_t stride) {
param::ConvBias cur_param;
cur_param.mode = param::ConvBias::Mode::CROSS_CORRELATION;
cur_param.sparse = ConvBias::Param::Sparse::GROUP;
......@@ -711,42 +713,52 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
.set_dtype(2, dtype)
.set_dtype(3, dtype)
.set_dtype(4, dtype);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(scale, 2 * scale);
checker.set_rng(0, &rng)
.set_rng(1, &rng)
.set_rng(2, &rng)
.set_rng(3, &rng)
.set_rng(4, &rng);
if (dtype.enumv() == DTypeEnum::Float16) {
checker.set_epsilon(1e-1);
}
cur_param.pad_h = cur_param.pad_w = fh / 2;
cur_param.stride_h = cur_param.stride_w = 1;
cur_param.pad_h = cur_param.pad_w = padding;
cur_param.stride_h = cur_param.stride_w = stride;
checker.set_param(cur_param).execs(
{{n, g, h, h}, {g, 1, 1, fh, fh}, {}, {}, {}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 5);
run(4, 8, 64, 7);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
run(4, 8, 32, 5, 5 / 2, 1);
run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
run(4, 8, 64, 5, 5 / 3, 2);
run(4, 8, 64, 7, 7 / 3, 2);
run(4, 8, 64, 9, 9 / 3, 2);
run(4, 8, 64, 11, 11 / 3, 2);
run(4, 8, 64, 13, 13 / 3, 2);
run(4, 8, 64, 15, 15 / 3, 2);
run(4, 8, 64, 17, 17 / 3, 2);
run(4, 8, 64, 19, 19 / 3, 2);
run(4, 8, 64, 21, 21 / 3, 2);
run(4, 8, 64, 23, 23 / 3, 2);
run(4, 8, 64, 25, 25 / 3, 2);
run(4, 8, 64, 27, 27 / 3, 2);
run(4, 8, 64, 29, 29 / 3, 2);
run(4, 8, 64, 31, 31 / 3, 2);
run(1, 2, 128, 31, 10, 2);
run(1, 2, 256, 31, 10, 2);
}
}
......@@ -1530,7 +1542,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_TENSORCORE_INT8) {
run_bench(256, 512, 7, 7, 2048, 1, 1, 1, 1, 1000);
}
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
......@@ -1552,6 +1564,11 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
param.stride_h = sh;
param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(4, dtype::Float16());
bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
......@@ -1562,25 +1579,13 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
out.total_nr_elems()) /
(1024 * 1024 * 1024) * 1e3;
bencher.set_param(param)
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
auto fp32_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(4, dtype::Float16());
auto fp16_time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
"%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
"fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.\n",
auto time_in_ms = bencher.execs({inp, kern, {}, {}, out}) / nr_times;
auto ops = 2.0 * batch * g * ho * wo * fh * fw / (time_in_ms * 1e-3) * 1e-12;
printf("chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, time: "
"%.2fms, "
"perf: %.2f Tops bandwidth: %.2fGB/s.\n",
inp.to_string().c_str(), kern.to_string().c_str(),
out.to_string().c_str(), fp32_time_in_ms, fp16_time_in_ms,
fp32_time_in_ms / fp16_time_in_ms, bandwith * 4 / fp32_time_in_ms,
bandwith * 2 / fp16_time_in_ms);
out.to_string().c_str(), time_in_ms, ops, bandwith * 4 / time_in_ms);
};
run_bench(64, 384, 32, 32, 3, 3, 1, 1, 10);
......@@ -1600,7 +1605,7 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
run_bench(64, 384, 32, 32, 31, 31, 1, 1, 10);
}
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP32) {
require_compute_capability(7, 5);
Benchmarker<ConvBiasForward> bencher(handle_cuda());
bencher.set_display(false);
......@@ -1623,10 +1628,10 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16) {
param.stride_w = sw;
bencher.set_param(param)
.set_dtype(0, dtype::Float16())
.set_dtype(1, dtype::Float16())
.set_dtype(2, dtype::Float16())
.set_dtype(4, dtype::Float16());
.set_dtype(0, dtype::Float32())
.set_dtype(1, dtype::Float32())
.set_dtype(2, dtype::Float32())
.set_dtype(4, dtype::Float32());
bencher.set_times(nr_times);
size_t ho = infer_conv_shape(hi, fh, sh, param.pad_h);
size_t wo = infer_conv_shape(wi, fw, sw, param.pad_w);
......
......@@ -728,48 +728,58 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker<ConvolutionBackwardData> checker(handle_cuda());
checker.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
for (auto dtype : std::vector<DType>{dtype::Float16()}) {
auto run = [&checker, &dtype](size_t n, size_t g, size_t h, size_t fh) {
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) {
auto run = [&checker, &dtype](
size_t n, size_t g, size_t h, size_t fh, size_t padding,
size_t stride) {
param::Convolution param;
param.stride_h = param.stride_w = 1;
param.pad_h = param.pad_w = fh / 2;
param.stride_h = param.stride_w = stride;
param.pad_h = param.pad_w = padding;
param.mode = Convolution::Mode::CROSS_CORRELATION;
param.sparse = param::Convolution::Sparse::GROUP;
checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(1.0, 1.0);
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &rng);
if (dtype.enumv() == DTypeEnum::Float16)
checker.set_epsilon(1e-1);
checker.set_param(param).execs(
{{g, 1, 1, fh, fh}, {n, g, h, h}, {n, g, h, h}});
{{g, 1, 1, fh, fh},
{n, g, (h + 2 * padding - fh + 1) / stride,
(h + 2 * padding - fh + 1) / stride},
{n, g, h, h}});
};
run(4, 8, 32, 5);
run(4, 8, 32, 7);
run(4, 8, 32, 9);
run(4, 8, 32, 11);
run(4, 8, 32, 13);
run(4, 8, 32, 15);
run(4, 8, 32, 17);
run(4, 8, 32, 19);
run(4, 8, 32, 21);
run(4, 8, 32, 23);
run(4, 8, 32, 25);
run(4, 8, 32, 27);
run(4, 8, 32, 29);
run(4, 8, 32, 31);
run(4, 8, 64, 7);
run(4, 8, 64, 5);
run(4, 8, 64, 9);
run(4, 8, 64, 11);
run(4, 8, 64, 13);
run(4, 8, 64, 15);
run(4, 8, 64, 17);
run(4, 8, 64, 19);
run(4, 8, 64, 21);
run(4, 8, 64, 23);
run(4, 8, 64, 25);
run(4, 8, 64, 27);
run(4, 8, 64, 29);
run(4, 8, 64, 31);
run(1, 2, 128, 31);
run(1, 2, 256, 31);
run(4, 8, 32, 5, 5 / 2, 1);
run(4, 8, 32, 7, 7/2, 1);
run(4, 8, 32, 9, 9/2, 1);
run(4, 8, 32, 11, 11/2, 1);
run(4, 8, 32, 13, 13/2, 1);
run(4, 8, 32, 15, 15/2, 1);
run(4, 8, 32, 17, 17/2, 1);
run(4, 8, 32, 19, 19/2, 1);
run(4, 8, 32, 21, 21/2, 1);
run(4, 8, 32, 23, 23/2, 1);
run(4, 8, 32, 25, 25/2, 1);
run(4, 8, 32, 27, 27/2, 1);
run(4, 8, 32, 29, 29/2, 1);
run(4, 8, 32, 31, 31/2, 1);
run(4, 8, 64, 5, 5 / 2, 2);
run(4, 8, 64, 7, 7/3, 2);
run(4, 8, 64, 9, 9/3, 2);
run(4, 8, 64, 11, 11/3, 2);
run(4, 8, 64, 13, 13/3, 2);
run(4, 8, 64, 15, 15/3, 2);
run(4, 8, 64, 17, 17/3, 2);
run(4, 8, 64, 19, 19/3, 2);
run(4, 8, 64, 21, 21/3, 2);
run(4, 8, 64, 23, 23/3, 2);
run(4, 8, 64, 25, 25/3, 2);
run(4, 8, 64, 27, 27/3, 2);
run(4, 8, 64, 29, 29/3, 2);
run(4, 8, 64, 31, 31/3, 2);
run(1, 2, 128, 31, 31/3, 2);
run(1, 2, 256, 31, 31/3, 2);
}
}
......@@ -950,7 +960,7 @@ TEST_F(CUDA, CONVOLUTION_BWD_DATA_BENCHMARK) {
run(32, 64, 64, 56, 56, 1, 1, 0);
}
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER_FP32) {
CUBenchmarker<ConvolutionBackwardData> bencher{handle_cuda()};
bencher.set_display(false);
bencher.set_before_exec_callback(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册