Commit d6b098a0 authored by Megvii Engine Team, committed by Xinran Xu

refactor(dnn/cuda): remove function name weiming

GitOrigin-RevId: f36495a46a36ab0976f3c70254af95e07cd92a80
Parent ca24c4cd
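Context for the diff below: the removed `check_input` helper enforced a hard-coded spatial bound, `IH*IW <= 768`, which (per the commented-out assert it carried) is exactly `4 * 4 * sizeof(float) * 768 = 49152` bytes, i.e. the common 48 KB per-block shared-memory limit baked in as a constant. The commit replaces that constant with a runtime comparison between the kernel's actual shared-memory demand and the device's reported limit. A minimal self-contained sketch of that check, using only the constants visible in this diff plus the standard CUDA device-properties API (the `main` harness and the example sizes are illustrative, not part of the commit):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Tile sizes used by the default local-conv kernel (from this diff):
// each block stages Ns batch samples x ICs input channels of the
// IH x IW input plane in shared memory.
constexpr size_t Ns = 4, ICs = 4;

size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) {
    return Ns * ICs * sizeof(float) * IH * IW;
}

int main() {
    cudaDeviceProp prop;
    cudaGetDeviceProperties(&prop, 0);
    const size_t IH = 28, IW = 28;  // hypothetical input spatial size
    size_t need = forward_proxy_default_share_mem_in_bytes(IH, IW);
    if (need <= prop.sharedMemPerBlock) {
        std::printf("default kernel usable: %zu of %zu bytes\n",
                    need, prop.sharedMemPerBlock);
    } else {
        // corresponds to the new megdnn_throw branch in the diff
        std::printf("no usable kernel: needs %zu bytes, limit is %zu\n",
                    need, prop.sharedMemPerBlock);
    }
    return 0;
}
```

On a device with more than 48 KB of shared memory per block this admits larger spatial sizes than the old assert did, and on smaller devices it fails loudly instead of launching a kernel that cannot fit.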
```diff
@@ -14,43 +14,7 @@
 #include "src/cuda/utils.h"
 #include "src/cuda/handle.h"
+#include "src/common/utils.cuh"
 
-namespace megdnn {
-namespace cuda {
-namespace local {
-
-void check_input(size_t N,
-        size_t IC, size_t IH, size_t IW,
-        size_t OC, size_t OH, size_t OW,
-        size_t FH, size_t FW,
-        size_t INs, size_t ONs,
-        size_t PH, size_t PW,
-        size_t SH, size_t SW,
-        bool is_xcorr)
-{
-    megdnn_ignore(N);
-    megdnn_ignore(IC);
-    megdnn_ignore(IH);
-    megdnn_ignore(IW);
-    megdnn_ignore(OC);
-    megdnn_ignore(OH);
-    megdnn_ignore(OW);
-    megdnn_ignore(FH);
-    megdnn_ignore(FW);
-    megdnn_ignore(INs);
-    megdnn_ignore(ONs);
-    megdnn_ignore(PH);
-    megdnn_ignore(PW);
-    megdnn_ignore(SH);
-    megdnn_ignore(SW);
-    megdnn_ignore(is_xcorr);
-    // shared memory constraint
-    megdnn_assert(IH*IW <= 768, "spatial size should not be larger than 768.");
-    // megdnn_assert(4 * 4 * 4 * IH * IW <= 49152);
-}
-
-} // namespace local
-} // namespace cuda
-} // namespace megdnn
-
 namespace megdnn {
 namespace cuda {
@@ -94,13 +58,9 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                 param().stride_h, param().stride_w,
                 cublas, stream,
                 one, zero);
-    } else {
-        local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW,
-                IC*IH*IW, OC*OH*OW,
-                param().pad_h, param().pad_w,
-                param().stride_h, param().stride_w,
-                is_xcorr);
-        local::forward_proxy_weiming(src.ptr<dt_float32>(),
+    } else if (local::forward_proxy_default_share_mem_in_bytes(IH, IW) <=
+               handle->device_prop().sharedMemPerBlock) {
+        local::forward_proxy_default(src.ptr<dt_float32>(),
                 filter.ptr<dt_float32>(),
                 dst.ptr<dt_float32>(),
                 N,
@@ -112,6 +72,11 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                 param().stride_h, param().stride_w,
                 is_xcorr,
                 stream);
+    } else {
+        megdnn_throw(ssprintf(
+                "No usable kernel for local conv, src: %s filter: %s \n",
+                src.layout.to_string().c_str(),
+                filter.layout.to_string().c_str()));
     }
 }
```
......
```diff
@@ -18,6 +18,12 @@ namespace megdnn {
 namespace cuda {
 namespace local {
 
+constexpr size_t Ns = 4, ICs = 4;
+
+size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW) {
+    return Ns * ICs * sizeof(float) * IH * IW;
+}
+
 // blockIdx.y is OC*OH*OW/1024
 // blockIdx.x is N/4
 // threadIdx.x is [0, 1024)
@@ -96,7 +102,7 @@ __global__ void forward_kernel(const float * __restrict__ src,
     }
 }
 
-void forward_proxy_weiming(const float *src, const float *filter, float *dst,
+void forward_proxy_default(const float *src, const float *filter, float *dst,
         size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
@@ -108,7 +114,6 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst,
         cudaStream_t stream)
 {
     size_t threads = 256;
-    const size_t Ns = 4, ICs = 4;
     dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC*OH*OW, threads));
     if (is_xcorr) {
         forward_kernel<Ns, ICs, true><<<blocks, threads,
```
......
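For context on where the byte count comes from: the launch in the last hunk uses a 2-D grid (`DIVUP(N, Ns)` blocks over the batch, `DIVUP(OC*OH*OW, threads)` over output elements), and the truncated launch expression presumably passes the shared-memory size as the third launch argument. A self-contained sketch of the same launch geometry, with a stand-in kernel body (the real `forward_kernel` takes a third template parameter for `is_xcorr` and real tensor arguments; everything here besides the constants and the grid formula is illustrative):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

#define DIVUP(a, b) (((a) + (b) - 1) / (b))

// Stand-in for forward_kernel: same template parameters and launch
// shape; the dynamic shared memory holds Ns * ICs input planes of
// IH * IW floats, matching forward_proxy_default_share_mem_in_bytes.
template <size_t Ns, size_t ICs>
__global__ void kernel_stub() {
    extern __shared__ float cache[];  // Ns * ICs * IH * IW floats
    (void)cache;
}

int main() {
    constexpr size_t Ns = 4, ICs = 4;
    // hypothetical problem sizes, for illustration only
    size_t N = 16, OC = 8, OH = 12, OW = 12, IH = 14, IW = 14;
    size_t threads = 256;
    dim3 blocks = dim3(DIVUP(N, Ns), DIVUP(OC * OH * OW, threads));
    size_t smem = Ns * ICs * sizeof(float) * IH * IW;
    kernel_stub<Ns, ICs><<<blocks, threads, smem>>>();
    cudaDeviceSynchronize();
    std::printf("grid=(%u,%u), threads=%zu, smem=%zu bytes\n",
                blocks.x, blocks.y, threads, smem);
    return 0;
}
```

Hoisting `Ns` and `ICs` to namespace scope (the first hunk above) is what lets the new byte-count helper and the kernel launch share one definition instead of duplicating the constants.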
```diff
@@ -17,17 +17,10 @@ namespace megdnn {
 namespace cuda {
 namespace local {
 
-void check_input(size_t N,
-        size_t IC, size_t IH, size_t IW,
-        size_t OC, size_t OH, size_t OW,
-        size_t FH, size_t FW,
-        size_t INs, size_t ONs,
-        size_t PH, size_t PW,
-        size_t SH, size_t SW,
-        bool is_xcorr);
+size_t forward_proxy_default_share_mem_in_bytes(size_t IH, size_t IW);
 
-void forward_proxy_weiming(const float *src, const float *filter, float *dst,
+void forward_proxy_default(const float *src, const float *filter, float *dst,
         size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
@@ -39,7 +32,7 @@ void forward_proxy_weiming(const float *src, const float *filter, float *dst,
 /// forward
 bool can_forward_proxy_convnet(size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
@@ -70,7 +63,7 @@ size_t get_workspace_in_floats_forward_proxy_convnet(size_t N,
 /// bwd data
 bool can_backward_data_proxy_convnet(size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
@@ -78,7 +71,7 @@ bool can_backward_data_proxy_convnet(size_t N,
         size_t PH, size_t PW,
         size_t SH, size_t SW);
 void backward_data_proxy_convnet(const float *filter,
         const float *diff,
         float *grad,
         float *workspace,
@@ -103,7 +96,7 @@ size_t get_workspace_in_floats_backward_data_proxy_convnet(size_t N,
 /// bwd filter
 bool can_backward_filter_proxy_convnet(size_t N,
         size_t IC, size_t IH, size_t IW,
         size_t OC, size_t OH, size_t OW,
         size_t FH, size_t FW,
```
......