Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
bc385b53
MegEngine
项目概览
MegEngine 天元
/
MegEngine
1 年多 前同步成功
通知
403
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
bc385b53
编写于
2月 07, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
feat(cuda): support float16 depthwise large kernel conv
GitOrigin-RevId: fdc1b15fbcb3968e695601bff6b6a953bf66f115
上级
7d2063e3
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
421 addition
and
89 deletion
+421
-89
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
...c/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
+208
-40
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
+10
-1
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
+14
-1
dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
...cuda/convolution/backward_data/depthwise_large_filter.cpp
+13
-1
dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu
dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu
+10
-1
dnn/test/cuda/conv_bias.cpp
dnn/test/cuda/conv_bias.cpp
+116
-44
dnn/test/cuda/convolution.cpp
dnn/test/cuda/convolution.cpp
+50
-1
未找到文件。
dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl
浏览文件 @
bc385b53
...
@@ -98,9 +98,11 @@ struct ConvTrait {
...
@@ -98,9 +98,11 @@ struct ConvTrait {
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_src_h + smem_buff_h;
static int const smem_load_h = smem_src_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = OutTileConfig::block_w +
static int const smem_w =
FilterTileConfig::unroll_w * ThreadConfig::thread_x -
DIVUP(OutTileConfig::block_w +
1;
FilterTileConfig::unroll_w * ThreadConfig::thread_x - 1,
2) *
2;
static int const smem_size = smem_h * smem_w;
static int const smem_size = smem_h * smem_w;
static int const load_w =
static int const load_w =
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
...
@@ -266,9 +268,11 @@ __device__ __forceinline__ void Global2SharedMem<
...
@@ -266,9 +268,11 @@ __device__ __forceinline__ void Global2SharedMem<
// one each in the lower and upper half of a tile.
// one each in the lower and upper half of a tile.
// Backprop input direction is the same as forward direction with the filter
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
// rotated by 180°.
template <typename
T, typename
ConvTrait, DepthwiseConv2dDirection kDirection>
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
const Param param, const T* input, const T* filter, T* output) {
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
using T2 = __half2;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
...
@@ -282,6 +286,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
...
@@ -282,6 +286,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
off_oh = threadIdx.y, off_ow = threadIdx.x;
const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
extern __shared__ __align__(8) unsigned char smem[];
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_src = reinterpret_cast<T*>(smem);
...
@@ -315,10 +323,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
...
@@ -315,10 +323,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
__syncthreads();
__syncthreads();
T
reg_src[SrcTileConfig::unroll_h * SrcTileConfig::
unroll_w],
T
2 reg_src[SrcTileConfig::unroll_h * t2_src_
unroll_w],
reg_flt[
FilterTileConfig::unroll_h * FilterTileConfig::
unroll_w];
reg_flt[
2][FilterTileConfig::unroll_h * t2_flt_
unroll_w];
T
sum[OutTileConfig::unroll_size] = {0.0
};
T
2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}
};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_src.copy();
...
@@ -326,23 +334,34 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
...
@@ -326,23 +334,34 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
#pragma unroll
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
#pragma unroll
for (int s_w = 0; s_w < SrcTileConfig::unroll_w; ++s_w) {
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
reg_src[s_h * SrcTileConfig::unroll_w + s_w] = smem_src_ptr
int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
[(off_oh + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
SrcTileCount::smem_w +
s_w * 2;
s_w];
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
}
}
}
}
#pragma unroll
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
#pragma unroll
for (int f_w = 0; f_w < FilterTileConfig::unroll_w; ++f_w) {
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
reg_flt[f_h * FilterTileConfig::unroll_w + f_w] = smem_flt_ptr
int flt_offset =
[(fh + f_h) % FilterTileCount::smem_h *
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
FilterTileCount::smem_w +
f_w * 2;
f_w];
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
: static_cast<T>(0.0),
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
static_cast<T>(0.0), static_cast<T>(0.0)};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
static_cast<T>(0.0)};
}
}
#pragma unroll
#pragma unroll
...
@@ -350,13 +369,14 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
...
@@ -350,13 +369,14 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
#pragma unroll
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
#pragma unroll
for (int fw = 0; fw <
FilterTileConfig::
unroll_w; ++fw) {
for (int fw = 0; fw <
t2_flt_
unroll_w; ++fw) {
#pragma unroll
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * OutTileConfig::unroll_w + ow] +=
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
ow];
ow / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
}
}
...
@@ -387,7 +407,156 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
...
@@ -387,7 +407,156 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
if (out_w_idx >= param.out_w)
if (out_w_idx >= param.out_w)
return;
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j];
sum[i * OutTileConfig::unroll_w + j].x +
sum[i * OutTileConfig::unroll_w + j].y;
}
}
}
}
}
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
using T2 = float2;
using ThreadConfig = typename ConvTrait::ThreadConfig;
using SrcTileConfig = typename ConvTrait::SrcTileConfig;
using FilterTileConfig = typename ConvTrait::FilterTileConfig;
using OutTileConfig = typename ConvTrait::OutTileConfig;
using SrcTileCount = typename ConvTrait::SrcTileCount;
using FilterTileCount = typename ConvTrait::FilterTileCount;
using SrcGlobal2ShareVisitor = typename ConvTrait::SrcGlobal2ShareVisitor;
using FilterGlobal2ShareVisitor = typename ConvTrait::FilterGlobal2ShareVisitor;
const bool is_fwd = (kDirection == DepthwiseConv2dDirection::DIRECTION_FORWARD);
int off_ochannel = blockIdx.x, off_obw = blockIdx.y, off_obh = blockIdx.z,
off_oh = threadIdx.y, off_ow = threadIdx.x;
const int t2_src_unroll_w = (SrcTileConfig::unroll_w + 1) / 2;
const int t2_flt_unroll_w = (FilterTileConfig::unroll_w + 2) / 2;
const int t2_out_unroll_w = (OutTileConfig::unroll_w + 1) / 2;
extern __shared__ __align__(8) unsigned char smem[];
static_assert(sizeof(T) <= 8, "Insufficient alignment detected");
T* smem_src = reinterpret_cast<T*>(smem);
T* smem_flt = reinterpret_cast<T*>(&smem_src[SrcTileCount::smem_size]);
int off_ichannel = off_ochannel / param.chl_mul,
off_fchannel = off_ichannel % param.src_chl,
out_start_h = off_obh * OutTileConfig::block_h,
out_start_w = off_obw * OutTileConfig::block_w,
src_start_h = out_start_h - param.pad_h,
src_start_w = out_start_w - param.pad_w,
out_base_h_idx = out_start_h + off_oh * OutTileConfig::unroll_h;
T* smem_src_ptr = smem_src + off_ow * FilterTileConfig::unroll_w;
T* smem_flt_ptr = smem_flt + off_ow * FilterTileConfig::unroll_w;
T* out_base_ptr = output + off_ochannel * param.out_h * param.out_w;
SrcGlobal2ShareVisitor gl2sh_src(
smem_src, param.src_w, src_start_h, src_start_w, param.src_h, param.src_w);
FilterGlobal2ShareVisitor gl2sh_flt = {
smem_flt, param.flt_w, is_fwd ? 0 : param.flt_h - 2,
0, param.flt_h, param.flt_w};
gl2sh_src.g_ptr = input + off_ichannel * param.src_h * param.src_w;
gl2sh_flt.g_ptr = filter + off_fchannel * param.flt_h * param.flt_w;
gl2sh_src.first_copy();
gl2sh_flt.first_copy();
__syncthreads();
T2 reg_src[SrcTileConfig::unroll_h * t2_src_unroll_w],
reg_flt[2][FilterTileConfig::unroll_h * t2_flt_unroll_w];
T2 sum[OutTileConfig::unroll_size] = {{0.0, 0.0}};
for (int fh = 0; fh < param.flt_h; fh += FilterTileConfig::unroll_h) {
gl2sh_src.copy();
gl2sh_flt.copy();
#pragma unroll
for (int s_h = 0; s_h < SrcTileConfig::unroll_h; ++s_h) {
#pragma unroll
for (int s_w = 0; s_w < t2_src_unroll_w; ++s_w) {
int src_offset = (off_oh + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w * 2;
reg_src[s_h * t2_src_unroll_w + s_w] =
*reinterpret_cast<T2*>(smem_src_ptr + src_offset);
}
}
#pragma unroll
for (int f_h = 0; f_h < FilterTileConfig::unroll_h; ++f_h) {
#pragma unroll
for (int f_w = 0; f_w < t2_flt_unroll_w - 1; ++f_w) {
int flt_offset =
(fh + f_h) % FilterTileCount::smem_h * FilterTileCount::smem_w +
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
: static_cast<T>(0.0),
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
static_cast<T>(0.0), static_cast<T>(0.0)};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
static_cast<T>(0.0)};
}
#pragma unroll
for (int inner_fh = 0; inner_fh < FilterTileConfig::unroll_h; ++inner_fh) {
#pragma unroll
for (int oh = 0; oh < OutTileConfig::unroll_h; ++oh) {
#pragma unroll
for (int fw = 0; fw < t2_flt_unroll_w; ++fw) {
#pragma unroll
for (int ow = 0; ow < OutTileConfig::unroll_w; ++ow) {
sum[oh * t2_out_unroll_w + ow] = megdnn::cuda::fma2(
reg_flt[ow % 2][inner_fh * t2_flt_unroll_w + fw],
reg_src[(inner_fh + oh) * t2_src_unroll_w + fw +
ow / 2],
sum[oh * t2_out_unroll_w + ow]);
}
}
}
}
__syncthreads();
gl2sh_src.commit();
gl2sh_flt.commit();
gl2sh_src.iter_forward();
gl2sh_flt.iter_forward();
__syncthreads();
}
for (int o = 0; o < OutTileConfig::unroll_size; ++o) {
for (int i = 1; i < ThreadConfig::thread_x; i = i << 1) {
sum[o].x += __shfl_xor(sum[o].x, i, 32);
sum[o].y += __shfl_xor(sum[o].y, i, 32);
}
}
if (threadIdx.x == 0) {
#pragma unroll
for (int i = 0; i < OutTileConfig::unroll_h; ++i) {
int out_h_idx = out_base_h_idx + i;
if (out_h_idx < param.out_h) {
#pragma unroll
for (int j = 0; j < OutTileConfig::unroll_w; ++j) {
int out_w_idx = out_start_w + j;
if (out_w_idx >= param.out_w)
return;
out_base_ptr[out_h_idx * param.out_w + out_w_idx] =
sum[i * OutTileConfig::unroll_w + j].x +
sum[i * OutTileConfig::unroll_w + j].y;
}
}
}
}
}
}
...
@@ -419,28 +588,27 @@ void LaunchDepthwiseConv2dGPUSmall(
...
@@ -419,28 +588,27 @@ void LaunchDepthwiseConv2dGPUSmall(
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
void (*kernel)(const Param, const T*, const T*, T*);
void (*kernel)(const Param, const T*, const T*, T*);
kernel = DepthwiseConv2dGPUKernelNCHWSmall<
T,
IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHWSmall<IConvTrait, kDirection>;
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch();
after_kernel_launch();
}
}
#define INSTANCE_AB(
a, b, direction)
\
#define INSTANCE_AB(
type1, type2, a, b, direction)
\
if (param.out_w > b * 4) {
\
if (param.out_w > b * 4) { \
LaunchDepthwiseConv2dGPUSmall<
float, float2, direction, a + 1
, b + 1>( \
LaunchDepthwiseConv2dGPUSmall<
type1, type2, direction, a + 2
, b + 1>( \
param, src, flt, dst, stream);
\
param, src, flt, dst, stream); \
}
}
#define INSTANCE_A(a, direction) \
#define INSTANCE_A(type1, type2, a, direction) \
if (param.flt_w > 0) { \
if (param.flt_w > a * 4) { \
INSTANCE_AB(a, 15, direction) \
INSTANCE_AB(type1, type2, a, 15, direction) \
else INSTANCE_AB(a, 14, direction) else INSTANCE_AB(a, 13, direction) else INSTANCE_AB( \
else INSTANCE_AB(type1, type2, a, 14, direction) else INSTANCE_AB(type1, type2, a, 13, direction) else INSTANCE_AB(type1, type2, a, 12, direction) else INSTANCE_AB(type1, type2, a, 11, direction) else INSTANCE_AB(type1, type2, a, 10, direction) else INSTANCE_AB( \
a, 12, direction) else INSTANCE_AB(a, 11, direction) else INSTANCE_AB(a, 10, direction) else INSTANCE_AB(a, 9, direction) else INSTANCE_AB(a, 8, direction) else INSTANCE_AB(a, 7, direction) else INSTANCE_AB(a, 6, direction) else INSTANCE_AB(a, 5, direction) else INSTANCE_AB(a, 4, direction) else INSTANCE_AB(a, 3, direction) else INSTANCE_AB(a, 2, direction) else INSTANCE_AB(a, 1, direction) else INSTANCE_AB(a, 0, direction) \
type1, type2, \
a, 9, direction) else INSTANCE_AB(type1, type2, a, 8, direction) else INSTANCE_AB(type1, type2, a, 7, direction) else INSTANCE_AB(type1, type2, a, 6, direction) else INSTANCE_AB(type1, type2, a, 5, direction) else INSTANCE_AB(type1, type2, a, 4, direction) else INSTANCE_AB(type1, type2, a, 3, direction) else INSTANCE_AB(type1, type2, a, 2, direction) else INSTANCE_AB(type1, type2, a, 1, direction) else INSTANCE_AB(type1, type2, a, 0, direction) \
}
}
#define INSTANCE(direction) \
#define INSTANCE(type1, type2, direction) \
INSTANCE_A(7, direction) \
INSTANCE_A(type1, type2, 6, direction) \
else INSTANCE_A(6, direction) else INSTANCE_A(5, direction) else INSTANCE_A(4, direction) else INSTANCE_A( \
else INSTANCE_A(type1, type2, 4, direction) else INSTANCE_A( \
3, \
type1, type2, 2, direction) else INSTANCE_A(type1, type2, 0, direction)
direction) else INSTANCE_A(2, direction) else INSTANCE_A(1, direction) else INSTANCE_A(0, direction)
} // anonymous namespace
} // anonymous namespace
dnn/src/cuda/conv_bias/chanwise/fwd_large_filter.cu
浏览文件 @
bc385b53
...
@@ -37,9 +37,18 @@ template <>
...
@@ -37,9 +37,18 @@ template <>
void
run_fwd_depthwise_large_filter
(
void
run_fwd_depthwise_large_filter
(
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
Param
&
param
,
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
Param
&
param
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
INSTANCE
(
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
)
INSTANCE
(
float
,
float2
,
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
)
}
}
#if CUDA_VERSION >= 9000
template
<
>
void
run_fwd_depthwise_large_filter
(
__half
*
dst
,
const
__half
*
src
,
const
__half
*
flt
,
const
Param
&
param
,
cudaStream_t
stream
)
{
INSTANCE
(
__half
,
__half2
,
DepthwiseConv2dDirection
::
DIRECTION_FORWARD
)
}
#endif
}
// namespace chanwise
}
// namespace chanwise
}
// namespace conv_bias
}
// namespace conv_bias
}
// namespace cuda
}
// namespace cuda
...
...
dnn/src/cuda/conv_bias/depthwise_large_filter.cpp
浏览文件 @
bc385b53
...
@@ -50,7 +50,11 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
...
@@ -50,7 +50,11 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
return
false
;
return
false
;
}
}
if
(
args
.
src_layout
->
dtype
!=
args
.
filter_layout
->
dtype
&&
if
(
args
.
src_layout
->
dtype
!=
args
.
filter_layout
->
dtype
&&
args
.
src_layout
->
dtype
!=
dtype
::
Float32
())
{
(
args
.
src_layout
->
dtype
!=
dtype
::
Float32
()
#if CUDA_VERSION >= 9000
||
args
.
src_layout
->
dtype
!=
dtype
::
Float16
()
#endif
))
{
return
false
;
return
false
;
}
}
if
(
args
.
z_layout
->
ndim
>
0
)
if
(
args
.
z_layout
->
ndim
>
0
)
...
@@ -97,6 +101,15 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c
...
@@ -97,6 +101,15 @@ void ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::exec(const ExecArgs& args) c
conv_dst_tensor
.
ptr
<
float
>
(),
args
.
src_tensor
->
ptr
<
float
>
(),
conv_dst_tensor
.
ptr
<
float
>
(),
args
.
src_tensor
->
ptr
<
float
>
(),
args
.
filter_tensor
->
ptr
<
float
>
(),
kparam
,
stream
);
args
.
filter_tensor
->
ptr
<
float
>
(),
kparam
,
stream
);
break
;
break
;
#if CUDA_VERSION >= 9000
case
DTypeEnum
::
Float16
:
chanwise
::
run_fwd_depthwise_large_filter
(
static_cast
<
half
*>
(
conv_dst_tensor
.
raw_ptr
()),
static_cast
<
half
*>
(
args
.
src_tensor
->
raw_ptr
()),
static_cast
<
half
*>
(
args
.
filter_tensor
->
raw_ptr
()),
kparam
,
stream
);
break
;
#endif
default:
default:
megdnn_assert_internal
(
0
);
megdnn_assert_internal
(
0
);
}
}
...
...
dnn/src/cuda/convolution/backward_data/depthwise_large_filter.cpp
浏览文件 @
bc385b53
...
@@ -49,7 +49,11 @@ bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
...
@@ -49,7 +49,11 @@ bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
return
false
;
return
false
;
}
}
if
(
args
.
diff_layout
->
dtype
!=
args
.
filter_layout
->
dtype
&&
if
(
args
.
diff_layout
->
dtype
!=
args
.
filter_layout
->
dtype
&&
args
.
diff_layout
->
dtype
!=
dtype
::
Float32
())
{
(
args
.
diff_layout
->
dtype
!=
dtype
::
Float32
()
#if CUDA_VERSION >= 9000
||
args
.
diff_layout
->
dtype
!=
dtype
::
Float16
()
#endif
))
{
return
false
;
return
false
;
}
}
...
@@ -78,6 +82,14 @@ void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
...
@@ -78,6 +82,14 @@ void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
args
.
grad_tensor
->
ptr
<
float
>
(),
args
.
diff_tensor
->
ptr
<
float
>
(),
args
.
grad_tensor
->
ptr
<
float
>
(),
args
.
diff_tensor
->
ptr
<
float
>
(),
args
.
filter_tensor
->
ptr
<
float
>
(),
kparam
,
stream
);
args
.
filter_tensor
->
ptr
<
float
>
(),
kparam
,
stream
);
break
;
break
;
#if CUDA_VERSION >= 9000
case
DTypeEnum
::
Float16
:
chanwise
::
run_bwd_depthwise_large_filter
(
static_cast
<
half
*>
(
args
.
grad_tensor
->
raw_ptr
()),
static_cast
<
half
*>
(
args
.
diff_tensor
->
raw_ptr
()),
static_cast
<
half
*>
(
args
.
filter_tensor
->
raw_ptr
()),
kparam
,
stream
);
break
;
#endif
default:
default:
megdnn_assert_internal
(
0
);
megdnn_assert_internal
(
0
);
}
}
...
...
dnn/src/cuda/convolution/chanwise/bwd_large_filter.cu
浏览文件 @
bc385b53
...
@@ -34,9 +34,18 @@ template <>
...
@@ -34,9 +34,18 @@ template <>
void
run_bwd_depthwise_large_filter
(
void
run_bwd_depthwise_large_filter
(
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
Param
&
param
,
float
*
dst
,
const
float
*
src
,
const
float
*
flt
,
const
Param
&
param
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
INSTANCE
(
DepthwiseConv2dDirection
::
DIRECTION_BACKWARD
)
INSTANCE
(
float
,
float2
,
DepthwiseConv2dDirection
::
DIRECTION_BACKWARD
)
}
}
#if CUDA_VERSION >= 9000
template
<
>
void
run_bwd_depthwise_large_filter
(
__half
*
dst
,
const
__half
*
src
,
const
__half
*
flt
,
const
Param
&
param
,
cudaStream_t
stream
)
{
INSTANCE
(
__half
,
__half2
,
DepthwiseConv2dDirection
::
DIRECTION_BACKWARD
)
}
#endif
}
// namespace chanwise
}
// namespace chanwise
}
// namespace convolution
}
// namespace convolution
}
// namespace cuda
}
// namespace cuda
...
...
dnn/test/cuda/conv_bias.cpp
浏览文件 @
bc385b53
...
@@ -701,51 +701,53 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
...
@@ -701,51 +701,53 @@ TEST_F(CUDA, CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
ConvBiasForward
::
algo_name
<
ConvBias
::
DirectParam
>
(
ConvBiasForward
::
algo_name
<
ConvBias
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
.
c_str
()));
auto
run
=
[
&
checker
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
)
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float16
()})
{
param
::
ConvBias
cur_param
;
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
)
{
cur_param
.
mode
=
param
::
ConvBias
::
Mode
::
CROSS_CORRELATION
;
param
::
ConvBias
cur_param
;
cur_param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
cur_param
.
mode
=
param
::
ConvBias
::
Mode
::
CROSS_CORRELATION
;
checker
.
set_dtype
(
0
,
dtype
::
Float32
())
cur_param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
.
set_dtype
(
1
,
dtype
::
Float32
())
checker
.
set_dtype
(
0
,
dtype
)
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
)
.
set_dtype
(
3
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
)
.
set_dtype
(
4
,
dtype
::
Float32
());
.
set_dtype
(
3
,
dtype
)
.
set_dtype
(
4
,
dtype
);
cur_param
.
pad_h
=
cur_param
.
pad_w
=
fh
/
2
;
cur_param
.
pad_h
=
cur_param
.
pad_w
=
fh
/
2
;
cur_param
.
stride_h
=
cur_param
.
stride_w
=
1
;
cur_param
.
stride_h
=
cur_param
.
stride_w
=
1
;
checker
.
set_param
(
cur_param
).
execs
(
checker
.
set_param
(
cur_param
).
execs
(
{{
n
,
g
,
h
,
h
},
{
g
,
1
,
1
,
fh
,
fh
},
{},
{},
{}});
{{
n
,
g
,
h
,
h
},
{
g
,
1
,
1
,
fh
,
fh
},
{},
{},
{}});
};
};
run
(
4
,
8
,
32
,
5
);
run
(
4
,
8
,
32
,
5
);
run
(
4
,
8
,
32
,
7
);
run
(
4
,
8
,
32
,
7
);
run
(
4
,
8
,
32
,
9
);
run
(
4
,
8
,
32
,
9
);
run
(
4
,
8
,
32
,
11
);
run
(
4
,
8
,
32
,
11
);
run
(
4
,
8
,
32
,
13
);
run
(
4
,
8
,
32
,
13
);
run
(
4
,
8
,
32
,
15
);
run
(
4
,
8
,
32
,
15
);
run
(
4
,
8
,
32
,
17
);
run
(
4
,
8
,
32
,
17
);
run
(
4
,
8
,
32
,
19
);
run
(
4
,
8
,
32
,
19
);
run
(
4
,
8
,
32
,
21
);
run
(
4
,
8
,
32
,
21
);
run
(
4
,
8
,
32
,
23
);
run
(
4
,
8
,
32
,
23
);
run
(
4
,
8
,
32
,
25
);
run
(
4
,
8
,
32
,
25
);
run
(
4
,
8
,
32
,
27
);
run
(
4
,
8
,
32
,
27
);
run
(
4
,
8
,
32
,
29
);
run
(
4
,
8
,
32
,
29
);
run
(
4
,
8
,
32
,
31
);
run
(
4
,
8
,
32
,
31
);
run
(
4
,
8
,
64
,
5
);
run
(
4
,
8
,
64
,
5
);
run
(
4
,
8
,
64
,
7
);
run
(
4
,
8
,
64
,
7
);
run
(
4
,
8
,
64
,
9
);
run
(
4
,
8
,
64
,
9
);
run
(
4
,
8
,
64
,
11
);
run
(
4
,
8
,
64
,
11
);
run
(
4
,
8
,
64
,
13
);
run
(
4
,
8
,
64
,
13
);
run
(
4
,
8
,
64
,
15
);
run
(
4
,
8
,
64
,
15
);
run
(
4
,
8
,
64
,
17
);
run
(
4
,
8
,
64
,
17
);
run
(
4
,
8
,
64
,
19
);
run
(
4
,
8
,
64
,
19
);
run
(
4
,
8
,
64
,
21
);
run
(
4
,
8
,
64
,
21
);
run
(
4
,
8
,
64
,
23
);
run
(
4
,
8
,
64
,
23
);
run
(
4
,
8
,
64
,
25
);
run
(
4
,
8
,
64
,
25
);
run
(
4
,
8
,
64
,
27
);
run
(
4
,
8
,
64
,
27
);
run
(
4
,
8
,
64
,
29
);
run
(
4
,
8
,
64
,
29
);
run
(
4
,
8
,
64
,
31
);
run
(
4
,
8
,
64
,
31
);
run
(
1
,
2
,
128
,
31
);
run
(
1
,
2
,
128
,
31
);
run
(
1
,
2
,
256
,
31
);
run
(
1
,
2
,
256
,
31
);
}
}
}
TEST_F
(
CUDA
,
CONV_BIAS_FORWARD_CHANWISE_8x8x32
)
{
TEST_F
(
CUDA
,
CONV_BIAS_FORWARD_CHANWISE_8x8x32
)
{
...
@@ -1550,11 +1552,81 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
...
@@ -1550,11 +1552,81 @@ TEST_F(CUDA, BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER) {
param
.
stride_h
=
sh
;
param
.
stride_h
=
sh
;
param
.
stride_w
=
sw
;
param
.
stride_w
=
sw
;
bencher
.
set_times
(
nr_times
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
TensorShape
inp
{
batch
,
g
,
hi
,
wi
},
kern
{
g
,
1
,
1
,
fh
,
fw
},
out
{
batch
,
g
,
ho
,
wo
};
float
bandwith
=
static_cast
<
float
>
(
inp
.
total_nr_elems
()
+
kern
.
total_nr_elems
()
+
out
.
total_nr_elems
())
/
(
1024
*
1024
*
1024
)
*
1e3
;
bencher
.
set_param
(
param
)
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
0
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
1
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
2
,
dtype
::
Float32
())
.
set_dtype
(
4
,
dtype
::
Float32
());
.
set_dtype
(
4
,
dtype
::
Float32
());
auto
fp32_time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float16
())
.
set_dtype
(
1
,
dtype
::
Float16
())
.
set_dtype
(
2
,
dtype
::
Float16
())
.
set_dtype
(
4
,
dtype
::
Float16
());
auto
fp16_time_in_ms
=
bencher
.
execs
({
inp
,
kern
,
{},
{},
out
})
/
nr_times
;
printf
(
"chanwise_depthwise_large_filter: inp=%s, kern=%s, out=%s, fp32_time: "
"%.2fms, fp16_time: %.2fms, speedup: %0.2f (fp16/fp32) "
"fp32_bandwidth: %.2fGB/s fp16_bandwidth: %.2fGB/s.
\n
"
,
inp
.
to_string
().
c_str
(),
kern
.
to_string
().
c_str
(),
out
.
to_string
().
c_str
(),
fp32_time_in_ms
,
fp16_time_in_ms
,
fp32_time_in_ms
/
fp16_time_in_ms
,
bandwith
*
4
/
fp32_time_in_ms
,
bandwith
*
2
/
fp16_time_in_ms
);
};
run_bench
(
64
,
384
,
32
,
32
,
3
,
3
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
5
,
5
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
7
,
7
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
9
,
9
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
11
,
11
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
13
,
13
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
15
,
15
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
17
,
17
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
19
,
19
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
21
,
21
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
23
,
23
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
25
,
25
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
27
,
27
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
29
,
29
,
1
,
1
,
10
);
run_bench
(
64
,
384
,
32
,
32
,
31
,
31
,
1
,
1
,
10
);
}
TEST_F
(
CUDA
,
BENCHMARK_CONV_BIAS_FORWARD_DEPTHWISE_LARGE_FILTER_FP16
)
{
require_compute_capability
(
7
,
5
);
Benchmarker
<
ConvBiasForward
>
bencher
(
handle_cuda
());
bencher
.
set_display
(
false
);
bencher
.
set_before_exec_callback
(
conv_bias
::
ConvBiasAlgoChecker
<
ConvBiasForward
>
(
ConvBiasForward
::
algo_name
<
ConvBiasForward
::
DirectParam
>
(
"DEPTHWISE_LARGE_FILTER"
,
{})
.
c_str
()));
ConvBias
::
Param
param
;
param
.
format
=
ConvBias
::
Param
::
Format
::
NCHW
;
using
NonlineMode
=
ConvBias
::
Param
::
NonlineMode
;
param
.
nonlineMode
=
NonlineMode
::
IDENTITY
;
param
.
sparse
=
ConvBias
::
Param
::
Sparse
::
GROUP
;
auto
run_bench
=
[
&
](
size_t
batch
,
size_t
g
,
size_t
hi
,
size_t
wi
,
size_t
fh
,
size_t
fw
,
size_t
sh
,
size_t
sw
,
size_t
nr_times
)
{
param
.
pad_h
=
fh
/
2
;
param
.
pad_w
=
fw
/
2
;
param
.
stride_h
=
sh
;
param
.
stride_w
=
sw
;
bencher
.
set_param
(
param
)
.
set_dtype
(
0
,
dtype
::
Float16
())
.
set_dtype
(
1
,
dtype
::
Float16
())
.
set_dtype
(
2
,
dtype
::
Float16
())
.
set_dtype
(
4
,
dtype
::
Float16
());
bencher
.
set_times
(
nr_times
);
bencher
.
set_times
(
nr_times
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
ho
=
infer_conv_shape
(
hi
,
fh
,
sh
,
param
.
pad_h
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
size_t
wo
=
infer_conv_shape
(
wi
,
fw
,
sw
,
param
.
pad_w
);
...
...
dnn/test/cuda/convolution.cpp
浏览文件 @
bc385b53
...
@@ -728,7 +728,7 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
...
@@ -728,7 +728,7 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
Checker
<
ConvolutionBackwardData
>
checker
(
handle_cuda
());
Checker
<
ConvolutionBackwardData
>
checker
(
handle_cuda
());
checker
.
set_before_exec_callback
(
checker
.
set_before_exec_callback
(
AlgoChecker
<
ConvolutionBackwardData
>
(
"DEPTHWISE_LARGE_FILTER"
));
AlgoChecker
<
ConvolutionBackwardData
>
(
"DEPTHWISE_LARGE_FILTER"
));
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float
32
()})
{
for
(
auto
dtype
:
std
::
vector
<
DType
>
{
dtype
::
Float
16
()})
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
)
{
auto
run
=
[
&
checker
,
&
dtype
](
size_t
n
,
size_t
g
,
size_t
h
,
size_t
fh
)
{
param
::
Convolution
param
;
param
::
Convolution
param
;
param
.
stride_h
=
param
.
stride_w
=
1
;
param
.
stride_h
=
param
.
stride_w
=
1
;
...
@@ -999,6 +999,55 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
...
@@ -999,6 +999,55 @@ TEST_F(CUDA, BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER) {
run
(
64
,
384
,
384
,
32
,
32
,
31
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
31
,
1
,
10
);
}
}
TEST_F
(
CUDA
,
BENCHMARK_CONVOLUTION_BWD_DATA_DEPTHWISE_LARGE_FILTER_FP16
)
{
CUBenchmarker
<
ConvolutionBackwardData
>
bencher
{
handle_cuda
()};
bencher
.
set_display
(
false
);
bencher
.
set_before_exec_callback
(
AlgoChecker
<
ConvolutionBackwardData
>
(
"DEPTHWISE_LARGE_FILTER"
));
auto
run
=
[
&
](
size_t
N
,
size_t
OC
,
size_t
g
,
size_t
IH
,
size_t
IW
,
size_t
FH
,
size_t
SH
,
size_t
nr_times
)
{
bencher
.
set_dtype
(
0
,
dtype
::
Float16
())
.
set_dtype
(
1
,
dtype
::
Float16
())
.
set_dtype
(
2
,
dtype
::
Float16
());
param
::
Convolution
param
;
param
.
stride_h
=
param
.
stride_w
=
SH
;
param
.
pad_h
=
param
.
pad_w
=
FH
/
2
;
param
.
sparse
=
param
::
Convolution
::
Sparse
::
GROUP
;
bencher
.
set_param
(
param
);
bencher
.
set_times
(
nr_times
);
TensorLayout
src
{{
N
,
g
,
IH
,
IW
},
dtype
::
Float16
()},
filter
{{
g
,
1
,
1
,
FH
,
FH
},
dtype
::
Float16
()};
TensorLayout
dst
;
{
auto
&&
opr
=
handle_cuda
()
->
create_operator
<
Convolution
>
();
opr
->
param
()
=
param
;
opr
->
deduce_layout
(
src
,
filter
,
dst
);
}
auto
time_ms_fp16
=
bencher
.
execl
({
filter
,
dst
,
src
})
/
nr_times
;
float
flo
=
2.0
*
N
*
g
*
dst
[
2
]
*
dst
[
3
]
*
FH
*
FH
;
printf
(
"inp=%s, kern=%s, dst=%s "
,
src
.
to_string
().
c_str
(),
filter
.
to_string
().
c_str
(),
dst
.
to_string
().
c_str
());
printf
(
"time_fp16=%.2fms, flops=%.3fTFLOPS
\n
"
,
time_ms_fp16
,
(
flo
/
(
time_ms_fp16
*
1e9
)));
};
run
(
64
,
384
,
384
,
32
,
32
,
3
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
5
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
7
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
9
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
11
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
13
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
15
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
17
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
19
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
21
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
23
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
25
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
27
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
29
,
1
,
10
);
run
(
64
,
384
,
384
,
32
,
32
,
31
,
1
,
10
);
}
TEST_F
(
CUDA
,
BENCHMARK_CONVOLUTION_BWD_DATA_BF16
)
{
TEST_F
(
CUDA
,
BENCHMARK_CONVOLUTION_BWD_DATA_BF16
)
{
CUBenchmarker
<
ConvolutionBackwardData
>
bench
{
handle_cuda
()};
CUBenchmarker
<
ConvolutionBackwardData
>
bench
{
handle_cuda
()};
std
::
unique_ptr
<
OprProxy
<
ConvolutionBackwardData
>>
proxy
{
std
::
unique_ptr
<
OprProxy
<
ConvolutionBackwardData
>>
proxy
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录