Commit d6b098a0 authored by Megvii Engine Team, committed by Xinran Xu

refactor(dnn/cuda): remove function name weiming

GitOrigin-RevId: f36495a46a36ab0976f3c70254af95e07cd92a80
Parent ca24c4cd
@@ -14,43 +14,7 @@
#include "src/cuda/utils.h"
#include "src/cuda/handle.h"
namespace megdnn {
namespace cuda {
namespace local {
void check_input(size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
size_t FH, size_t FW,
size_t INs, size_t ONs,
size_t PH, size_t PW,
size_t SH, size_t SW,
bool is_xcorr)
{
megdnn_ignore(N);
megdnn_ignore(IC);
megdnn_ignore(IH);
megdnn_ignore(IW);
megdnn_ignore(OC);
megdnn_ignore(OH);
megdnn_ignore(OW);
megdnn_ignore(FH);
megdnn_ignore(FW);
megdnn_ignore(INs);
megdnn_ignore(ONs);
megdnn_ignore(PH);
megdnn_ignore(PW);
megdnn_ignore(SH);
megdnn_ignore(SW);
megdnn_ignore(is_xcorr);
// shared memory constraint
megdnn_assert(IH*IW <= 768, "spatial size should not be larger than 768.");
// megdnn_assert(4 * 4 * 4 * IH * IW <= 49152);
}
} // namespace local
} // namespace cuda
} // namespace megdnn
#include "src/common/utils.cuh"
namespace megdnn {
namespace cuda {
@@ -94,13 +58,9 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
param().stride_h, param().stride_w,
cublas, stream,
one, zero);
} else {
local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW,
IC*IH*IW, OC*OH*OW,
param().pad_h, param().pad_w,
param().stride_h, param().stride_w,
is_xcorr);
local::forward_proxy_weiming(src.ptr<dt_float32>(),
} else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <=
handle->device_prop().sharedMemPerBlock) {
local::forward_proxy_default(src.ptr<dt_float32>(),
filter.ptr<dt_float32>(),
dst.ptr<dt_float32>(),
N,
@@ -112,6 +72,11 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
param().stride_h, param().stride_w,
is_xcorr,
stream);
} else {
megdnn_throw(ssprintf(
"No usable kernel for local conv, src: %s filter: %s \n",
src.layout.to_string().c_str(),
filter.layout.to_string().c_str()));
}
}
......
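For context on the hunk above: the removed check_input() asserted a fixed spatial limit, while the rewritten exec() asks the device how much shared memory a block may use and only reports an error when neither the cuBLAS-backed path nor the default kernel applies. Below is a minimal standalone sketch of that runtime check using the plain CUDA runtime API; it is not MegDNN code, and the helper name and the 24x24 spatial size are illustrative only.

// Standalone sketch (not MegDNN code): the shared-memory budget check that
// drives the new dispatch in exec().
#include <cstdio>
#include <cuda_runtime.h>

// Same formula as forward_proxy_default_share_mem_in_bytes() in the diff.
static size_t default_kernel_smem_in_bytes(size_t IH, size_t IW) {
    constexpr size_t Ns = 4, ICs = 4;
    return Ns * ICs * sizeof(float) * IH * IW;
}

int main() {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, 0);

    const size_t IH = 24, IW = 24;  // illustrative spatial size
    const size_t need = default_kernel_smem_in_bytes(IH, IW);

    if (need <= prop.sharedMemPerBlock) {
        std::printf("default kernel fits: needs %zu of %zu bytes\n", need,
                    prop.sharedMemPerBlock);
    } else {
        std::printf("no usable kernel: needs %zu bytes, device allows %zu\n",
                    need, prop.sharedMemPerBlock);
    }
    return 0;
}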
@@ -18,6 +18,12 @@ namespace megdnn {
namespace cuda {
namespace local {
constexpr size_t Ns = 4, ICs = 4;
size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) {
return Ns * ICs * sizeof(float) * IH * IW;
}
// blockIdx.y is OC*OH*OW/1024
// blockIdx.x is N/4
// threadIdx.x is [0, 1024)
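A quick check on the helper added above: with Ns = ICs = 4 and sizeof(float) = 4, it requests 64 * IH * IW bytes of shared memory. The guard this commit removes, megdnn_assert(IH*IW <= 768) (and the commented-out 4 * 4 * 4 * IH * IW <= 49152), is the same formula evaluated against 49152 bytes, i.e. the 48 KiB shared-memory-per-block limit typical of many CUDA devices; the refactor now reads that limit from handle->device_prop().sharedMemPerBlock instead of hard-coding it.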
@@ -96,7 +102,7 @@ __global__ void forward_kernel(const float * __restrict__ src,
}
}
void forward_proxy_weiming(const float *src, const float *filter, float *dst,
void forward_proxy_default(const float *src, const float *filter, float *dst,
size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
@@ -108,7 +114,6 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst,
cudaStream_t stream)
{
size_t threads = 256;
const size_t Ns = 4, ICs = 4;
dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC*OH*OW, threads));
if (is_xcorr) {
forward_kernel<Ns, ICs, true><<<blocks, threads,
......
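The launch in the hunk above maps blockIdx.x over batches of Ns samples and blockIdx.y over chunks of the OC*OH*OW outputs, and the per-function const size_t Ns = 4, ICs = 4; it drops now lives as the namespace-level constexpr introduced earlier, so the launch configuration and the host-side shared-memory estimate share one definition. A hypothetical, self-contained sketch of that launch pattern follows (kernel body elided; the names are illustrative, not the actual megdnn kernel):

// Hypothetical sketch of the launch pattern: a templated kernel launched with
// a dynamic shared-memory size equal to the host-side estimate.
#include <cuda_runtime.h>

#define DIVUP(a, b) (((a) + (b) - 1) / (b))

template <size_t Ns, size_t ICs, bool is_xcorr>
__global__ void forward_kernel_sketch(const float* src, const float* filter,
                                      float* dst, size_t IH, size_t IW) {
    extern __shared__ float smem[];  // Ns * ICs * IH * IW staged input floats
    // ... convolution body elided ...
}

void launch_sketch(const float* src, const float* filter, float* dst, size_t N,
                   size_t OC, size_t OH, size_t OW, size_t IH, size_t IW,
                   bool is_xcorr, cudaStream_t stream) {
    constexpr size_t Ns = 4, ICs = 4;
    const unsigned threads = 256;
    dim3 blocks(static_cast<unsigned>(DIVUP(N, Ns)),
                static_cast<unsigned>(DIVUP(OC * OH * OW, threads)));
    const size_t smem_bytes = Ns * ICs * sizeof(float) * IH * IW;
    if (is_xcorr) {
        forward_kernel_sketch<Ns, ICs, true>
                <<<blocks, threads, smem_bytes, stream>>>(src, filter, dst, IH, IW);
    } else {
        forward_kernel_sketch<Ns, ICs, false>
                <<<blocks, threads, smem_bytes, stream>>>(src, filter, dst, IH, IW);
    }
}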
@@ -17,17 +17,10 @@ namespace megdnn {
namespace cuda {
namespace local {
void check_input(size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
size_t FH, size_t FW,
size_t INs, size_t ONs,
size_t PH, size_t PW,
size_t SH, size_t SW,
bool is_xcorr);
size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW);
void forward_proxy_weiming(const float *src, const float *filter, float *dst,
size_t N,
void forward_proxy_default(const float *src, const float *filter, float *dst,
size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
size_t FH, size_t FW,
@@ -39,7 +32,7 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst,
/// forward
bool can_forward_proxy_convnet(size_t N,
bool can_forward_proxy_convnet(size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
size_t FH, size_t FW,
@@ -70,7 +63,7 @@ size_t get_workspace_in_floats_forward_proxy_convnet(size_t N,
/// bwd data
bool can_backward_data_proxy_convnet(size_t N,
bool can_backward_data_proxy_convnet(size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
size_t FH, size_t FW,
@@ -78,7 +71,7 @@ bool can_backward_data_proxy_convnet(size_t N,
size_t PH, size_t PW,
size_t SH, size_t SW);
void backward_data_proxy_convnet(const float *filter,
void backward_data_proxy_convnet(const float *filter,
const float *diff,
float *grad,
float *workspace,
@@ -103,7 +96,7 @@ size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N,
/// bwd filter
bool can_backward_filter_proxy_convnet(size_t N,
bool can_backward_filter_proxy_convnet(size_t N,
size_t IC, size_t IH, size_t IW,
size_t OC, size_t OH, size_t OW,
size_t FH, size_t FW,
......