提交 8a2e92bd 编写于 作者: M Megvii Engine Team

refactor(cuda): depthwise large kernel

GitOrigin-RevId: dade8710b4122dce27dcebfd74491e810288a160
上级 6b8a69d5
/**
* \file dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
namespace {
// Ceiling division for integer operands: x / y rounded up when x is
// non-negative and y is positive. The argument y is expanded twice, so do not
// pass an expression with side effects.
#define DIVUP(x, y) (((x) + (y)-1) / (y))
// Tag selecting which pass of the depthwise conv2d a kernel instantiation
// computes. Enumerator values are the implicit ones (0 and 1).
enum DepthwiseConv2dDirection {
    DIRECTION_FORWARD,
    DIRECTION_BACKWARD
};
// Compile-time output tile extents derived from a thread layout.
//   ThreadConfig_ : type exposing integral constants thread_x and thread_y
//   oh_           : value of unroll_h
//   ow_           : multiplied by ThreadConfig_::thread_x to give unroll_w
// block_h scales unroll_h by thread_y; block_w is the same as unroll_w.
template <typename ThreadConfig_, int oh_, int ow_>
struct OutTileConfig {
    using ThreadConfig = ThreadConfig_;

    static constexpr int unroll_h = oh_;
    static constexpr int unroll_w = ow_ * ThreadConfig::thread_x;
    static constexpr int unroll_size = unroll_w * unroll_h;

    static constexpr int block_w = unroll_w;
    static constexpr int block_h = ThreadConfig::thread_y * unroll_h;
};
// Compile-time filter tile extents: fh_ by fw_ elements, with unroll_size
// holding their product.
template <int fh_, int fw_>
struct FilterTileConfig {
    static constexpr int unroll_h = fh_;
    static constexpr int unroll_w = fw_;
    static constexpr int unroll_size = fw_ * fh_;
};
// Compile-time thread layout of x_ by y_ threads (nr_threads = x_ * y_).
// Instantiation fails unless x_ has at most one bit set, i.e. the
// "power of two" check below passes.
template <int x_, int y_>
struct ThreadConfig {
    static_assert((x_ & (x_ - 1)) == 0, "thread_x must be pow of 2!");

    static constexpr int thread_x = x_;
    static constexpr int thread_y = y_;
    static constexpr int nr_threads = thread_x * thread_y;
};
// Compile-time tile bookkeeping shared by the depthwise large-filter code.
// Everything here is derived from the thread layout, the output tile, the
// filter tile and the two strides; nothing depends on runtime values.
//
//   ldg_dtype         : exposed unchanged as CompType
//   ThreadConfig_     : provides thread_x and nr_threads
//   OutTileConfig_    : provides unroll_h/unroll_w and block_h/block_w
//   FilterTileConfig_ : provides unroll_h/unroll_w
//   stride_w/stride_h : horizontal / vertical stride used in the extents
template <
        typename ldg_dtype, typename ThreadConfig_, typename OutTileConfig_,
        typename FilterTileConfig_, int stride_w, int stride_h>
struct ConvTraitInner {
    using CompType = ldg_dtype;
    using ThreadConfig = ThreadConfig_;
    using OutTileConfig = OutTileConfig_;
    using FilterTileConfig = FilterTileConfig_;

    // Source extent matching one "unroll" output tile. Note that the height
    // does not involve stride_h, while the width is scaled by stride_w.
    struct SrcTileConfig {
        static constexpr int unroll_h =
                FilterTileConfig::unroll_h + OutTileConfig::unroll_h - 1;
        static constexpr int unroll_w =
                FilterTileConfig::unroll_w + stride_w * (OutTileConfig::unroll_w - 1);
        static constexpr int unroll_size = unroll_w * unroll_h;
    };

    // Source tile sized from the block-level output tile; the smem_* naming
    // suggests a shared-memory staging buffer. load_* / reg_* describe how
    // one row of smem_w elements is split across nr_threads.
    struct SrcTileCount {
        static constexpr int smem_src_h =
                FilterTileConfig::unroll_h + stride_h * (OutTileConfig::block_h - 1);
        static constexpr int smem_buff_h = FilterTileConfig::unroll_h;
        static constexpr int smem_load_h = smem_buff_h + smem_src_h;
        // The buffer height is added a second time on top of smem_load_h.
        static constexpr int smem_h = smem_buff_h + smem_load_h;
        // Width is rounded up to the next multiple of 2.
        static constexpr int smem_w =
                2 * DIVUP(stride_w * (OutTileConfig::block_w - 1) +
                                  ThreadConfig::thread_x * FilterTileConfig::unroll_w,
                          2);
        static constexpr int smem_size = smem_w * smem_h;
        // load_w is min(smem_w, nr_threads).
        static constexpr int load_w =
                ThreadConfig::nr_threads < smem_w ? ThreadConfig::nr_threads : smem_w;
        static constexpr int load_h = 1;
        static constexpr int reg_h = 1;
        static constexpr int reg_w = DIVUP(smem_w, load_w);
        // True when the extent is not an exact multiple of the load extent.
        static constexpr bool check_bounds_h = (smem_h % load_h) != 0;
        static constexpr bool check_bounds_w = (smem_w % load_w) != 0;
    };

    // Filter tile counterpart of SrcTileCount; rows are thread_x filter
    // tiles wide and a load row is capped at 32 elements.
    struct FilterTileCount {
        static constexpr int smem_flt_h = FilterTileConfig::unroll_h;
        static constexpr int smem_buff_h = FilterTileConfig::unroll_h;
        static constexpr int smem_load_h = smem_buff_h + smem_flt_h;
        static constexpr int smem_h = smem_buff_h + smem_load_h;
        static constexpr int smem_w =
                ThreadConfig::thread_x * FilterTileConfig::unroll_w;
        static constexpr int smem_size = smem_w * smem_h;
        // load_w is min(smem_w, 32); the remaining threads stack along load_h.
        static constexpr int load_w = 32 < smem_w ? 32 : smem_w;
        static constexpr int load_h = ThreadConfig::nr_threads / load_w;
        static constexpr int reg_h = 1;
        static constexpr int reg_w = DIVUP(smem_w, load_w);
        static constexpr bool check_bounds_h = (smem_h % load_h) != 0;
        static constexpr bool check_bounds_w = (smem_w % load_w) != 0;
    };
};
/*
 * CHECK_AB_FWD(a, b) -- one link of the forward-pass availability chain.
 *
 * Expands to an `if` statement with no trailing `else`, so links can be
 * chained as `CHECK_AB_FWD(a, 15) else CHECK_AB_FWD(a, 14) ...`. The
 * expansion site must provide `param` (out_w, stride_h and stride_w are
 * read), `device_prop` (regsPerBlock and sharedMemPerBlock are read) and the
 * integer constants `unroll_oh` / `unroll_fh`, and must sit inside a function
 * returning bool.
 *
 * When param.out_w > b * 4 and the stride is 1x1 or 2x2, the float tile
 * configuration (filter tile width a + 2, output tile width b + 1,
 * ThreadConfig<4, 32>) is instantiated and the macro returns whether the
 * device limits cover it:
 *   - registers: nr_threads * (2 * filter tile + source tile) values;
 *   - shared memory: (source + filter) tile elements times sizeof(float).
 * smem_size counts elements while sharedMemPerBlock is a byte count, so the
 * element total is converted to bytes before the comparison.
 * For any other stride nothing is returned; because this link's `if` was
 * taken, the rest of the chain is skipped and control continues after it.
 */
#define CHECK_AB_FWD(a, b)                                                         \
    if (param.out_w > (b) * 4) {                                                   \
        if (param.stride_h == 1 && param.stride_w == 1) {                          \
            using FilterTileConfig_ = FilterTileConfig<unroll_fh, (a) + 2>;        \
            using ThreadConfig_ = ThreadConfig<4, 32>;                             \
            using OutTileConfig_ =                                                 \
                    OutTileConfig<ThreadConfig_, unroll_oh, (b) + 1>;              \
            using IConvTrait = ConvTraitInner<                                     \
                    float, ThreadConfig_, OutTileConfig_, FilterTileConfig_, 1,    \
                    1>;                                                            \
            using SrcTileConfig = typename IConvTrait::SrcTileConfig;              \
            using SrcTileCount = typename IConvTrait::SrcTileCount;                \
            using FilterTileCount = typename IConvTrait::FilterTileCount;          \
                                                                                   \
            if (device_prop.regsPerBlock <                                         \
                        ThreadConfig_::nr_threads *                                \
                                (FilterTileConfig_::unroll_h *                     \
                                         FilterTileConfig_::unroll_w * 2 +         \
                                 SrcTileConfig::unroll_h *                         \
                                         SrcTileConfig::unroll_w) ||               \
                device_prop.sharedMemPerBlock <                                    \
                        static_cast<size_t>(                                       \
                                SrcTileCount::smem_size +                          \
                                FilterTileCount::smem_size) *                      \
                                sizeof(float)) {                                   \
                return false;                                                      \
            }                                                                      \
            return true;                                                           \
        } else if (param.stride_h == 2 && param.stride_w == 2) {                   \
            using FilterTileConfig_ = FilterTileConfig<unroll_fh, (a) + 2>;        \
            using ThreadConfig_ = ThreadConfig<4, 32>;                             \
            using OutTileConfig_ =                                                 \
                    OutTileConfig<ThreadConfig_, unroll_oh, (b) + 1>;              \
            using IConvTrait = ConvTraitInner<                                     \
                    float, ThreadConfig_, OutTileConfig_, FilterTileConfig_, 2,    \
                    2>;                                                            \
            using SrcTileConfig = typename IConvTrait::SrcTileConfig;              \
            using SrcTileCount = typename IConvTrait::SrcTileCount;                \
            using FilterTileCount = typename IConvTrait::FilterTileCount;          \
                                                                                   \
            if (device_prop.regsPerBlock <                                         \
                        ThreadConfig_::nr_threads *                                \
                                (FilterTileConfig_::unroll_h *                     \
                                         FilterTileConfig_::unroll_w * 2 +         \
                                 SrcTileConfig::unroll_h *                         \
                                         SrcTileConfig::unroll_w) ||               \
                device_prop.sharedMemPerBlock <                                    \
                        static_cast<size_t>(                                       \
                                SrcTileCount::smem_size +                          \
                                FilterTileCount::smem_size) *                      \
                                sizeof(float)) {                                   \
                return false;                                                      \
            }                                                                      \
            return true;                                                           \
        }                                                                          \
    }
/*
 * CHECK_AB_BWD(a, b) -- one link of the backward-data availability chain.
 *
 * Expands to an `if` statement with no trailing `else`, so links can be
 * chained as `CHECK_AB_BWD(a, 15) else CHECK_AB_BWD(a, 14) ...`. The
 * expansion site must provide `param` (out_w is read), `device_prop`
 * (regsPerBlock and sharedMemPerBlock are read) and the integer constants
 * `unroll_oh` / `unroll_fh`, and must sit inside a function returning bool.
 *
 * When param.out_w > b * 4, the float tile configuration (filter tile width
 * a + 2, output tile width b + 1, ThreadConfig<4, 32>) is instantiated with
 * stride 1x1 regardless of param's stride, and the macro returns whether the
 * device limits cover it:
 *   - registers: nr_threads * (2 * filter tile + source tile) values;
 *   - shared memory: (source + filter) tile elements times sizeof(float).
 * smem_size counts elements while sharedMemPerBlock is a byte count, so the
 * element total is converted to bytes before the comparison.
 */
#define CHECK_AB_BWD(a, b)                                                         \
    if (param.out_w > (b) * 4) {                                                   \
        using FilterTileConfig_ = FilterTileConfig<unroll_fh, (a) + 2>;            \
        using ThreadConfig_ = ThreadConfig<4, 32>;                                 \
        using OutTileConfig_ = OutTileConfig<ThreadConfig_, unroll_oh, (b) + 1>;   \
        using IConvTrait = ConvTraitInner<                                         \
                float, ThreadConfig_, OutTileConfig_, FilterTileConfig_, 1, 1>;    \
        using SrcTileConfig = typename IConvTrait::SrcTileConfig;                  \
        using SrcTileCount = typename IConvTrait::SrcTileCount;                    \
        using FilterTileCount = typename IConvTrait::FilterTileCount;              \
                                                                                   \
        if (device_prop.regsPerBlock <                                             \
                    ThreadConfig_::nr_threads *                                    \
                            (FilterTileConfig_::unroll_h *                         \
                                     FilterTileConfig_::unroll_w * 2 +             \
                             SrcTileConfig::unroll_h * SrcTileConfig::unroll_w) || \
            device_prop.sharedMemPerBlock <                                        \
                    static_cast<size_t>(                                           \
                            SrcTileCount::smem_size +                              \
                            FilterTileCount::smem_size) *                          \
                            sizeof(float)) {                                       \
            return false;                                                          \
        }                                                                          \
        return true;                                                               \
    }
/*
 * CHECK_A(a, cb) -- filter-width link of the availability chain.
 *
 * `cb` is pasted onto CHECK_AB_ to pick the per-direction macro (FWD or BWD).
 * When param.flt_w > a * 4, the CHECK_AB_<cb> links are tried with b running
 * from 15 down to 0; each link tests param.out_w > b * 4, so the first one
 * whose threshold is exceeded is the one that runs. Like its links, this
 * expands to an `if` with no trailing `else`, so CHECK_A invocations can be
 * chained with `else`.
 */
#define CHECK_A(a, cb)                \
    if (param.flt_w > a * 4) {        \
        CHECK_AB_##cb(a, 15)          \
        else CHECK_AB_##cb(a, 14)     \
        else CHECK_AB_##cb(a, 13)     \
        else CHECK_AB_##cb(a, 12)     \
        else CHECK_AB_##cb(a, 11)     \
        else CHECK_AB_##cb(a, 10)     \
        else CHECK_AB_##cb(a, 9)      \
        else CHECK_AB_##cb(a, 8)      \
        else CHECK_AB_##cb(a, 7)      \
        else CHECK_AB_##cb(a, 6)      \
        else CHECK_AB_##cb(a, 5)      \
        else CHECK_AB_##cb(a, 4)      \
        else CHECK_AB_##cb(a, 3)      \
        else CHECK_AB_##cb(a, 2)      \
        else CHECK_AB_##cb(a, 1)      \
        else CHECK_AB_##cb(a, 0)      \
    }
/*
 * CHECK(cb) -- entry point of the availability chain for direction `cb`
 * (FWD or BWD). Tries CHECK_A with a = 6, 4, 2, 0 in that order; each
 * CHECK_A tests param.flt_w > a * 4, so the first threshold the filter width
 * exceeds selects the branch. Statements following the expansion are reached
 * only when no CHECK_AB_<cb> link returned.
 */
#define CHECK(cb)           \
    CHECK_A(6, cb)          \
    else CHECK_A(4, cb)     \
    else CHECK_A(2, cb)     \
    else CHECK_A(0, cb)
} // namespace
/**
* \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.inl
* \file dnn/src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -9,35 +9,10 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#pragma once
#include "depthwise_large_filter.cuh"
#include "src/cuda/cuda_shfl_compat.cuh"
namespace {
enum DepthwiseConv2dDirection { DIRECTION_FORWARD, DIRECTION_BACKWARD };
template <typename ThreadConfig_, int oh_, int ow_>
struct OutTileConfig {
using ThreadConfig = ThreadConfig_;
static int const unroll_h = oh_;
static int const unroll_w = ThreadConfig::thread_x * ow_;
static int const unroll_size = unroll_h * unroll_w;
static int const block_h = unroll_h * ThreadConfig::thread_y;
static int const block_w = unroll_w;
};
template <int fh_, int fw_>
struct FilterTileConfig {
static int const unroll_h = fh_;
static int const unroll_w = fw_;
static int const unroll_size = unroll_h * unroll_w;
};
template <int x_, int y_>
struct ThreadConfig {
static int const thread_x = x_;
static_assert((thread_x & (thread_x - 1)) == 0, "thread_x must be pow of 2!");
static int const thread_y = y_;
static int const nr_threads = x_ * y_;
};
namespace {
template <
typename T, DepthwiseConv2dDirection kDirection, typename ThreadConfig_,
......@@ -87,49 +62,12 @@ struct ConvTrait {
using FilterTileConfig = FilterTileConfig_;
using CompType = ldg_dtype;
struct SrcTileConfig {
static int const unroll_h =
OutTileConfig::unroll_h + FilterTileConfig::unroll_h - 1;
static int const unroll_w =
(OutTileConfig::unroll_w - 1) * stride_w + FilterTileConfig::unroll_w;
static int const unroll_size = unroll_h * unroll_w;
};
struct SrcTileCount {
static int const smem_src_h =
(OutTileConfig::block_h - 1) * stride_h + FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_src_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w =
DIVUP((OutTileConfig::block_w - 1) * stride_w +
FilterTileConfig::unroll_w * ThreadConfig::thread_x,
2) *
2;
static int const smem_size = smem_h * smem_w;
static int const load_w =
smem_w > ThreadConfig::nr_threads ? ThreadConfig::nr_threads : smem_w;
static int const load_h = 1;
static int const reg_h = 1;
static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
};
struct FilterTileCount {
static int const smem_flt_h = FilterTileConfig::unroll_h;
static int const smem_buff_h = FilterTileConfig::unroll_h;
static int const smem_load_h = smem_flt_h + smem_buff_h;
static int const smem_h = smem_load_h + smem_buff_h;
static int const smem_w = FilterTileConfig::unroll_w * ThreadConfig::thread_x;
static int const smem_size = smem_h * smem_w;
static int const load_w = smem_w > 32 ? 32 : smem_w;
static int const load_h = ThreadConfig::nr_threads / load_w;
static int const reg_h = 1;
static int const reg_w = DIVUP(smem_w, load_w);
static bool constexpr check_bounds_h = smem_h % load_h != 0;
static bool constexpr check_bounds_w = smem_w % load_w != 0;
};
using CI = ConvTraitInner<
ldg_dtype, ThreadConfig_, OutTileConfig_, FilterTileConfig_, stride_w,
stride_h>;
using SrcTileConfig = typename CI::SrcTileConfig;
using SrcTileCount = typename CI::SrcTileCount;
using FilterTileCount = typename CI::FilterTileCount;
using SrcGlobal2ShareVisitor = Global2SharedMem<
CompType, DepthwiseConv2dDirection::DIRECTION_FORWARD, ThreadConfig,
......@@ -272,14 +210,15 @@ __device__ __forceinline__ void Global2SharedMem<
// CUDA kernel to compute the depthwise convolution forward pass in NCHW format,
// tailored for small images up to 32x32. Stride and depth multiplier must be 1.
// Padding must be 'SAME', which allows to reuse the index computation. Only
// use this kernel if CanLaunchDepthwiseConv2dGPUSmall(args) returns true.
// use this kernel if CanLaunchDepthwiseConv2dGPU(args) returns true.
// Tiles of the input and filter tensors are loaded into shared memory before
// performing the convolution. Each thread handles two elements per iteration,
// one each in the lower and upper half of a tile.
// Backprop input direction is the same as forward direction with the filter
// rotated by 180°.
#if CUDA_VERSION >= 9000
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const __half* input, const __half* filter, __half* output) {
using T = __half;
using T2 = __half2;
......@@ -380,16 +319,18 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
f_w * 2;
reg_flt[0][f_h * t2_flt_unroll_w + f_w] =
*reinterpret_cast<T2*>(smem_flt_ptr + flt_offset);
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
f_w > 0 ? reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y
: static_cast<T>(0.0),
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
if (f_w > 0) {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
reg_flt[0][f_h * t2_flt_unroll_w + f_w - 1].y,
reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
} else {
reg_flt[1][f_h * t2_flt_unroll_w + f_w] = {
0.0, reg_flt[0][f_h * t2_flt_unroll_w + f_w].x};
}
}
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
static_cast<T>(0.0), static_cast<T>(0.0)};
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {0.0, 0.0};
reg_flt[1][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 1] = {
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y,
static_cast<T>(0.0)};
reg_flt[0][f_h * t2_flt_unroll_w + t2_flt_unroll_w - 2].y, 0.0};
}
#pragma unroll
......@@ -444,9 +385,10 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
}
}
}
#endif
template <typename ConvTrait, DepthwiseConv2dDirection kDirection>
__global__ void DepthwiseConv2dGPUKernelNCHWSmall(
__global__ void DepthwiseConv2dGPUKernelNCHW(
const Param param, const float* input, const float* filter, float* output) {
using T = float;
using T2 = float2;
......@@ -530,11 +472,6 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
[(off_oh * stride_h + fh + s_h) % SrcTileCount::smem_h *
SrcTileCount::smem_w +
s_w];
if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 && off_oh == 30 &&
off_ow == 0) {
printf("reg_src[%d] = %f\n", s_h * SrcTileConfig::unroll_w + s_w,
reg_src[s_h * SrcTileConfig::unroll_w + s_w]);
}
}
}
......@@ -561,15 +498,6 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw] *
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w + fw +
ow * stride_w];
if (off_ochannel == 0 && off_obw == 0 && off_obh == 0 &&
off_oh == 30) {
printf("sum[%d] += %f * %f\nsum = %f\n",
oh * OutTileConfig::unroll_w + ow,
reg_flt[inner_fh * FilterTileConfig::unroll_w + fw],
reg_src[(inner_fh + oh) * SrcTileConfig::unroll_w +
fw + ow * stride_w],
sum[oh * OutTileConfig::unroll_w + ow]);
}
}
}
}
......@@ -610,7 +538,7 @@ __global__ void DepthwiseConv2dGPUKernelNCHWSmall(
template <
typename T, typename T2, DepthwiseConv2dDirection kDirection, int unroll_fw,
int unroll_ow, int stride>
void LaunchDepthwiseConv2dGPUSmall(
void LaunchDepthwiseConv2dGPU(
const Param& param, const T* input, const T* filter, T* output,
cudaStream_t stream) {
static int const unroll_oh = 1, unroll_fh = 1;
......@@ -633,22 +561,21 @@ void LaunchDepthwiseConv2dGPUSmall(
(SrcTileCount::smem_size + FilterTileCount::smem_size) * sizeof(T);
void (*kernel)(const Param, const T*, const T*, T*);
kernel = DepthwiseConv2dGPUKernelNCHWSmall<IConvTrait, kDirection>;
kernel = DepthwiseConv2dGPUKernelNCHW<IConvTrait, kDirection>;
kernel<<<grid, block, shared_storage, stream>>>(param, input, filter, output);
after_kernel_launch();
}
#define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4) { \
printf("param.out_w = %d, b = %d\n", param.out_w, b); \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1, 1>( \
param, src, flt, dst, stream); \
} else if (param.stride_h == 2 && param.stride_w == 2) { \
LaunchDepthwiseConv2dGPUSmall<type1, type2, direction, a + 2, b + 1, 2>( \
param, src, flt, dst, stream); \
} \
#define INSTANCE_AB(type1, type2, a, b, direction) \
if (param.out_w > b * 4) { \
if (direction == DepthwiseConv2dDirection::DIRECTION_BACKWARD || \
(param.stride_h == 1 && param.stride_w == 1)) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 1>( \
param, src, flt, dst, stream); \
} else if (param.stride_h == 2 && param.stride_w == 2) { \
LaunchDepthwiseConv2dGPU<type1, type2, direction, a + 2, b + 1, 2>( \
param, src, flt, dst, stream); \
} \
}
#define INSTANCE_A(type1, type2, a, direction) \
......
......@@ -21,7 +21,7 @@ using namespace cuda;
using namespace conv_bias;
using namespace chanwise;
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl"
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh"
namespace megdnn {
namespace cuda {
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh"
#include "src/common/conv_bias.h"
#include "src/cuda/conv_bias/algo.h"
#include "src/cuda/conv_bias/chanwise/kern.cuh"
......@@ -20,26 +21,13 @@ using namespace conv_bias;
namespace {
inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
auto&& device_prop = cuda::current_device_prop();
int flt_smem_w = (param.flt_w + 3) / 4 * 4;
int flt_smem_h = 3;
int flt_reg_per_thread =
flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
int ow = param.out_w > 64 ? 64 : param.out_w;
int src_smem_w = ow + flt_smem_w - 1;
int src_smem_h = flt_smem_h + param.flt_h - 1;
int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread * 2 +
src_reg_per_thread + out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
return false;
if ((param.stride_h == 1 && param.stride_w == 1) ||
(param.stride_h == 2 && param.stride_w == 2)) {
auto&& device_prop = cuda::current_device_prop();
static int const unroll_oh = 1, unroll_fh = 1;
CHECK(FWD)
}
return true;
return false;
}
} // anonymous namespace
......@@ -64,8 +52,8 @@ bool ConvBiasForwardImpl::AlgoDepthwiseLargeFilter::is_available(
return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
args.src_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
is_available_depthwise_large_filter(param);
}
......
......@@ -68,7 +68,7 @@ public:
const TensorLayout& grad);
convolution::ForwardSizeArgs as_fwd_args() const {
return {handle, diff_layout, filter_layout, filter_meta, grad_layout};
return {handle, grad_layout, filter_layout, filter_meta, diff_layout};
}
};
struct ExecArgs : public SizeArgs {
......
......@@ -9,6 +9,7 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter.cuh"
#include "src/cuda/convolution/backward_data/algo.h"
#include "src/cuda/convolution/chanwise/kern.cuh"
#include "src/cuda/utils.h"
......@@ -19,29 +20,13 @@ using namespace convolution;
namespace {
inline bool is_available_depthwise_large_filter(const chanwise::Param& param) {
auto&& device_prop = cuda::current_device_prop();
int flt_smem_w = (param.flt_w + 3) / 4 * 4;
int flt_smem_h = 3;
int flt_reg_per_thread =
flt_smem_w > 32 ? (flt_smem_w + 31) / 32 : 1 + flt_smem_w / 4;
int ow = param.out_w > 64 ? 64 : param.out_w;
int src_smem_w = ow + flt_smem_w - 1;
int src_smem_h = flt_smem_h + param.flt_h - 1;
int src_reg_per_thread = src_smem_w > 128 ? (flt_smem_w + 127) / 128
: 1 + (ow + 3) / 4 + flt_smem_w / 4 - 1;
int out_reg_per_thread = (ow + 3) / 4 * 4;
if (device_prop.regsPerBlock < 4 * 32 *
(flt_reg_per_thread * 2 +
src_reg_per_thread + out_reg_per_thread) ||
device_prop.sharedMemPerBlock <
static_cast<size_t>(
flt_smem_w * flt_smem_h * 2 + src_smem_w * src_smem_h)) {
return false;
if ((param.stride_h == 1 && param.stride_w == 1) ||
(param.stride_h == 2 && param.stride_w == 2)) {
auto&& device_prop = cuda::current_device_prop();
static int const unroll_oh = 1, unroll_fh = 1;
CHECK(BWD)
}
printf("param.src_w = %d, param.src_h = %d, param.out_w = %d, param.out_h = %d\n",
param.src_w, param.src_h, param.out_w, param.out_h);
return (param.stride_h == 1 && param.stride_w == 1) ||
(param.stride_h == 2 && param.stride_w == 2);
return false;
}
} // anonymous namespace
......@@ -59,13 +44,15 @@ bool ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::is_available(
return false;
}
auto param = chanwise::Param::from_fwd_args(args.as_fwd_args());
auto param = chanwise::Param::from_fwd_args(
{args.handle, args.diff_layout, args.filter_layout, args.filter_meta,
args.grad_layout});
auto&& fm = args.filter_meta;
return fm.group > 1 && args.filter_meta.format == Param::Format::NCHW &&
args.diff_layout->dtype.category() == DTypeCategory::FLOAT &&
args.opr->param().compute_mode == Param::ComputeMode::DEFAULT &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.dilation[0] == 1 &&
fm.dilation[1] == 1 && !fm.should_flip &&
fm.spatial_ndim == 2 && fm.icpg == 1 && fm.ocpg == 1 &&
fm.dilation[0] == 1 && fm.dilation[1] == 1 && !fm.should_flip &&
is_available_depthwise_large_filter(param);
}
......@@ -76,7 +63,9 @@ size_t ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::get_workspace_in_b
void ConvolutionBackwardDataImpl::AlgoDepthwiseLargeFilter::exec(
const ExecArgs& args) const {
auto kparam = chanwise::Param::from_fwd_args(args.as_fwd_args());
auto kparam = chanwise::Param::from_fwd_args(
{args.handle, args.diff_layout, args.filter_layout, args.filter_meta,
args.grad_layout});
auto stream = cuda_stream(args.handle);
switch (args.diff_layout->dtype.enumv()) {
case DTypeEnum::Float32:
......
/**
* \file dnn/src/cuda/conv_bias/chanwise/fwd_depthwise_large_filter.cu
* \file dnn/src/cuda/conv_bias/chanwise/bwd_large_filter.cu
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
......@@ -21,7 +21,7 @@ using namespace cuda;
using namespace convolution;
using namespace chanwise;
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.inl"
#include "src/cuda/conv_bias/chanwise/depthwise_large_filter_algo.cuh"
namespace megdnn {
namespace cuda {
......
......@@ -739,7 +739,7 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
param.sparse = param::Convolution::Sparse::GROUP;
checker.set_dtype(0, dtype).set_dtype(1, dtype).set_dtype(2, dtype);
float scale = 64.f / sqrt(fh * fh);
UniformFloatRNG rng(1.0, 1.0);
UniformFloatRNG rng(scale, scale * 2);
checker.set_rng(0, &rng).set_rng(1, &rng).set_rng(2, &rng);
if (dtype.enumv() == DTypeEnum::Float16)
checker.set_epsilon(1e-1);
......@@ -751,35 +751,35 @@ TEST_F(CUDA, CONVOLUTION_BACKWARD_DEPTHWISE_LARGE_FILTER) {
{n, g, h, h}});
};
run(4, 8, 32, 5, 5 / 2, 1);
run(4, 8, 32, 7, 7/2, 1);
run(4, 8, 32, 9, 9/2, 1);
run(4, 8, 32, 11, 11/2, 1);
run(4, 8, 32, 13, 13/2, 1);
run(4, 8, 32, 15, 15/2, 1);
run(4, 8, 32, 17, 17/2, 1);
run(4, 8, 32, 19, 19/2, 1);
run(4, 8, 32, 21, 21/2, 1);
run(4, 8, 32, 23, 23/2, 1);
run(4, 8, 32, 25, 25/2, 1);
run(4, 8, 32, 27, 27/2, 1);
run(4, 8, 32, 29, 29/2, 1);
run(4, 8, 32, 31, 31/2, 1);
run(4, 8, 32, 7, 7 / 2, 1);
run(4, 8, 32, 9, 9 / 2, 1);
run(4, 8, 32, 11, 11 / 2, 1);
run(4, 8, 32, 13, 13 / 2, 1);
run(4, 8, 32, 15, 15 / 2, 1);
run(4, 8, 32, 17, 17 / 2, 1);
run(4, 8, 32, 19, 19 / 2, 1);
run(4, 8, 32, 21, 21 / 2, 1);
run(4, 8, 32, 23, 23 / 2, 1);
run(4, 8, 32, 25, 25 / 2, 1);
run(4, 8, 32, 27, 27 / 2, 1);
run(4, 8, 32, 29, 29 / 2, 1);
run(4, 8, 32, 31, 31 / 2, 1);
run(4, 8, 64, 5, 5 / 2, 2);
run(4, 8, 64, 7, 7/3, 2);
run(4, 8, 64, 9, 9/3, 2);
run(4, 8, 64, 11, 11/3, 2);
run(4, 8, 64, 13, 13/3, 2);
run(4, 8, 64, 15, 15/3, 2);
run(4, 8, 64, 17, 17/3, 2);
run(4, 8, 64, 19, 19/3, 2);
run(4, 8, 64, 21, 21/3, 2);
run(4, 8, 64, 23, 23/3, 2);
run(4, 8, 64, 25, 25/3, 2);
run(4, 8, 64, 27, 27/3, 2);
run(4, 8, 64, 29, 29/3, 2);
run(4, 8, 64, 31, 31/3, 2);
run(1, 2, 128, 31, 31/3, 2);
run(1, 2, 256, 31, 31/3, 2);
run(4, 8, 64, 7, 7 / 3, 2);
run(4, 8, 64, 9, 9 / 3, 2);
run(4, 8, 64, 11, 11 / 3, 2);
run(4, 8, 64, 13, 13 / 3, 2);
run(4, 8, 64, 15, 15 / 3, 2);
run(4, 8, 64, 17, 17 / 3, 2);
run(4, 8, 64, 19, 19 / 3, 2);
run(4, 8, 64, 21, 21 / 3, 2);
run(4, 8, 64, 23, 23 / 3, 2);
run(4, 8, 64, 25, 25 / 3, 2);
run(4, 8, 64, 27, 27 / 3, 2);
run(4, 8, 64, 29, 29 / 3, 2);
run(4, 8, 64, 31, 31 / 3, 2);
run(1, 2, 128, 31, 31 / 3, 2);
run(1, 2, 256, 31, 31 / 3, 2);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册