提交 8e5410e4 编写于 作者: M Megvii Engine Team 提交者: 王彪

feat(cuda): add fp16 compute 16 kernel

GitOrigin-RevId: e03435be021ccf3d8eff357a80d5203e903aca96
上级 472e2f96
...@@ -235,7 +235,175 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -235,7 +235,175 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z, int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x; off_oh = threadIdx.y, off_ow = threadIdx.x;
constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2; constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 3) / 2;
constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h * stride_h - param.pad_h,
src_start_w = out_start_w * stride_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
param.src_w,
is_fwd ? src_start_h
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2),
is_fwd ? src_start_w
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2),
is_fwd ? param.src_h : param.src_h * param.stride_h,
is_fwd ? param.src_w : param.src_w * param.stride_w,
is_fwd ? 1 : param.stride_h,
is_fwd ? 1 : param.stride_w};
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
param.flt_w,
is_fwd ? 0 : param.flt_h - 2,
0,
param.flt_h,
param.flt_w,
1,
1};
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
gl2sh_src.first_copy();
gl2sh_flt.first_copy();
__syncthreads();
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset =
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
if (f_w > 0) {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[ow * stride_w % 2]
[inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
ow * stride_w / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
}
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o] = megdnn::cuda::hadd2(sum[o], __shfl_xor(sum[o], i, 32));
}
}
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] = __float2half(
__half2float(sum[i * OutTileConfig::unroll_w + j].x) +
__half2float(sum[i * OutTileConfig::unroll_w + j].y));
}
}
}
}
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
using T2 = __half2;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
constexpr int t2_src_unroll_w = (SrcTileConfig::unroll_w + 3) / 2;
constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2; constexpr int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2; constexpr int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
...@@ -320,17 +488,17 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -320,17 +488,17 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
reg_flt[0][f_h * t2_flt_unroll_w + f_w] = reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset); *reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
if (f_w > 0) { if (f_w > 0) {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = { reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y, T2{reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
} else { } else {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = { reg_flt[1][f_h * t2_flt_unroll_w + f_w] =
0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x}; T2{0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
} }
} }
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {0.0, 0.0}; reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = T2{0.0, 0.0};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = { reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] =
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0}; T2{reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
} }
#pragma unroll #pragma unroll
...@@ -535,6 +703,154 @@ __global__ void DepthwiseConv2dGPUKernelNCHW( ...@@ -535,6 +703,154 @@ __global__ void DepthwiseConv2dGPUKernelNCHW(
} }
} }
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWC32(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
using T2 = float2;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int stride_h = is_fwd ? param.stride_h : 1;
int stride_w = is_fwd ? param.stride_w : 1;
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h * stride_h - param.pad_h,
src_start_w = out_start_w * stride_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
SrcGlobal2ShareVisitor gl2sh_src = {
smem_src,
param.src_w,
is_fwd ? src_start_h
: src_start_h - (param.out_h / 2 + param.flt_h / 2 - param.pad_h -
param.src_h * param.stride_h / 2),
is_fwd ? src_start_w
: src_start_w - (param.out_w / 2 + param.flt_w / 2 - param.pad_w -
param.src_w * param.stride_w / 2),
is_fwd ? param.src_h : param.src_h * param.stride_h,
is_fwd ? param.src_w : param.src_w * param.stride_w,
is_fwd ? 1 : param.stride_h,
is_fwd ? 1 : param.stride_w};
FilterGlobal2ShareVisitor gl2sh_flt = {smem_flt,
param.flt_w,
is_fwd ? 0 : param.flt_h - 2,
0,
param.flt_h,
param.flt_w,
1,
1};
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
gl2sh_src.first_copy();
gl2sh_flt.first_copy();
__syncthreads();
T reg_src[SrcTileConfig::unroll_h * SrcTileConfig::unroll_w],
reg_flt[FilterTileConfig::unroll_h * FilterTileConfig::unroll_w];
T sum[OutTileConfig::unroll_size] = {0.0};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w];
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
[(fh + f_h) % FilterTileCount::smem_h *
FilterTileCount::smem_w +
f_w];
}
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < FilterTileConfig::unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
}
}
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
}
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o] += __shfl_xor(sum[o], i, 32);
}
}
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j];
}
}
}
}
}
template < template <
typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw, typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw,
int unroll_ow, int stride> int unroll_ow, int stride>
...@@ -561,7 +877,12 @@ void LaunchDepthwiseConv2dGPU( ...@@ -561,7 +877,12 @@ void LaunchDepthwiseConv2dGPU(
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T); (SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
void (*kernel)(const Param, const T*, const T*, T*); void (*kernel)(const Param, const T*, const T*, T*);
if (param.is_compute_deafult) {
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>; kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
} else {
kernel = DepthwiseConv2dGPUKernelNCHWC32<IConvTrait, kDirection>;
}
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output); kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch(); after_kernel_launch();
} }
......
...@@ -27,8 +27,10 @@ namespace chanwise { ...@@ -27,8 +27,10 @@ namespace chanwise {
struct Param { struct Param {
uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h, uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h,
pad_w, stride_h, stride_w, dilation_h, dilation_w; pad_w, stride_h, stride_w, dilation_h, dilation_w;
bool is_compute_deafult;
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
static Param from_fwd_args(const BiasForwardSizeArgs& args) { static Param from_fwd_args(
const BiasForwardSizeArgs& args, bool is_compute_deafult_ = true) {
#define U(v) static_cast<uint32_t>(v) #define U(v) static_cast<uint32_t>(v)
auto&& src = args.src_layout->shape; auto&& src = args.src_layout->shape;
auto&& dst = args.dst_layout->shape; auto&& dst = args.dst_layout->shape;
...@@ -47,6 +49,7 @@ struct Param { ...@@ -47,6 +49,7 @@ struct Param {
U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]), U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]),
U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]), U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]),
U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]), U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]),
is_compute_deafult_,
}; };
#undef U #undef U
} }
......
...@@ -47,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available( ...@@ -47,7 +47,8 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
if (args.z_layout->ndim > 0) if (args.z_layout->ndim > 0)
return false; return false;
auto param = chanwise::Param::from_fwd_args(args); auto param = chanwise::Param::from_fwd_args(
args, args.opr->param().compute_mode == Param::ComputeMode::DEFAULT);
auto&& fm = args.filter_meta; auto&& fm = args.filter_meta;
return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW && return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
args.src_layout->dtype.category() == DTypeCategory::FLOAT && args.src_layout->dtype.category() == DTypeCategory::FLOAT &&
...@@ -80,7 +81,8 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c ...@@ -80,7 +81,8 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c
conv_dst_tensor.layout.dtype); conv_dst_tensor.layout.dtype);
} }
{ {
auto kparam = chanwise::Param::from_fwd_args(args); auto kparam = chanwise::Param::from_fwd_args(
args, args.opr->param().compute_mode == Param::ComputeMode::DEFAULT);
auto stream = cuda_stream(args.handle); auto stream = cuda_stream(args.handle);
switch (args.src_layout->dtype.enumv()) { switch (args.src_layout->dtype.enumv()) {
case DTypeEnum::Float32: case DTypeEnum::Float32:
......
...@@ -27,8 +27,10 @@ namespace chanwise { ...@@ -27,8 +27,10 @@ namespace chanwise {
struct Param { struct Param {
uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h, uint32_t batch, src_chl, src_h, src_w, chl_mul, flt_h, flt_w, out_h, out_w, pad_h,
pad_w, stride_h, stride_w, dilation_h, dilation_w; pad_w, stride_h, stride_w, dilation_h, dilation_w;
bool is_compute_deafult;
#if MEGDNN_CC_HOST #if MEGDNN_CC_HOST
static Param from_fwd_args(const ForwardSizeArgs& args) { static Param from_fwd_args(
const ForwardSizeArgs& args, bool is_compute_deafult_ = true) {
#define U(v) static_cast<uint32_t>(v) #define U(v) static_cast<uint32_t>(v)
auto&& src = args.src_layout->shape; auto&& src = args.src_layout->shape;
auto&& dst = args.dst_layout->shape; auto&& dst = args.dst_layout->shape;
...@@ -47,6 +49,7 @@ struct Param { ...@@ -47,6 +49,7 @@ struct Param {
U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]), U(fm.spatial[1]), U(dst[hw_pos]), U(dst[hw_pos + 1]),
U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]), U(fm.padding[0]), U(fm.padding[1]), U(fm.stride[0]),
U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]), U(fm.stride[1]), U(fm.dilation[0]), U(fm.dilation[1]),
is_compute_deafult_,
}; };
#undef U #undef U
} }
......
...@@ -45,6 +45,15 @@ fma2(const __half2 a, const __half2 b, const __half2 c) { ...@@ -45,6 +45,15 @@ fma2(const __half2 a, const __half2 b, const __half2 c) {
#endif #endif
} }
__device__ __forceinline__ __half2 hadd2(const __half2 a, const __half2 b) {
#if __CUDA_ARCH__ >= 530
return __hadd2(a, b);
#else
return {__float2half(__half2float(a.x) + __half2float(b.x)),
__float2half(__half2float(a.y) + __half2float(b.y))};
#endif
}
__device__ __forceinline__ float2 __device__ __forceinline__ float2
fma2(const __half2 a, const __half2 b, const float2 c) { fma2(const __half2 a, const __half2 b, const float2 c) {
return {__half2float(a.x) * __half2float(b.x) + c.x, return {__half2float(a.x) * __half2float(b.x) + c.x,
......
...@@ -701,7 +701,12 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) { ...@@ -701,7 +701,12 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward::algo_name<ConvBias::DirectParam>( ConvBiasForward::algo_name<ConvBias::DirectParam>(
"DEPTHWISE_LARGE_FILTER", {}) "DEPTHWISE_LARGE_FILTER", {})
.c_str())); .c_str()));
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { for (auto dtype : std::vector<DType> {
dtype::Float32(),
#if CUDA_VERSION >= 9000
dtype::Float16()
#endif
}) {
auto run = [&checker, &dtype]( auto run = [&checker, &dtype](
size_t n, size_t g, size_t h, size_t fh, size_t padding, size_t n, size_t g, size_t h, size_t fh, size_t padding,
size_t stride) { size_t stride) {
......
...@@ -728,7 +728,12 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) { ...@@ -728,7 +728,12 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker<ConvolutionBackwardData> checker(handle_cuda()); Checker<ConvolutionBackwardData> checker(handle_cuda());
checker.set_before_exec_callback( checker.set_before_exec_callback(
AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER")); AlgoChecker<ConvolutionBackwardData>("DEPTHWISE_LARGE_FILTER"));
for (auto dtype : std::vector<DType>{dtype::Float32(), dtype::Float16()}) { for (auto dtype : std::vector<DType> {
dtype::Float32(),
#if CUDA_VERSION >= 9000
dtype::Float16()
#endif
}) {
auto run = [&checker, &dtype]( auto run = [&checker, &dtype](
size_t n, size_t g, size_t h, size_t fh, size_t padding, size_t n, size_t g, size_t h, size_t fh, size_t padding,
size_t stride) { size_t stride) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册