Commit ca24c4cd written by Megvii Engine Team, committed by Xinran Xu

refactor(dnn/cuda): invoke local opr in group local

GitOrigin-RevId: aa636041bb07c8e99ea4e5c7ae21d0f48f97ca30
Parent 91c3d5fe
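For context, this refactor replaces the hand-rolled per-group convnet/weiming proxies with a single LocalForward operator that is executed once per group on shifted tensor views. Below is a minimal, self-contained sketch of that dispatch pattern; `GroupView` and `run_single_group` are hypothetical stand-ins for illustration, not MegDNN symbols.

```cpp
// Sketch only: run a per-group operator over one packed [N, G*Cg, H, W]
// buffer by advancing a raw pointer group by group.
#include <cstddef>
#include <cstdio>
#include <vector>

struct GroupView {
    float* ptr;          // start of this group's channels
    size_t batch_stride; // stride (in elements) between consecutive batches
    size_t channels;     // channels per group (Cg)
};

// Placeholder for the single-group operator (LocalForward in the real code).
void run_single_group(const GroupView& in) {
    std::printf("group at %p, %zu channels per group\n",
                static_cast<void*>(in.ptr), in.channels);
}

int main() {
    const size_t N = 2, G = 4, Cg = 3, H = 5, W = 5;
    std::vector<float> buf(N * G * Cg * H * W, 0.f);

    GroupView view{buf.data(), /*batch_stride=*/G * Cg * H * W, Cg};
    for (size_t g = 0; g < G; ++g) {
        run_single_group(view);
        // Advance by one group's worth of channels; the batch stride stays
        // the full G*Cg*H*W, which is why the batch dim need not be contiguous.
        view.ptr += Cg * H * W;
    }
}
```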
@@ -29,7 +29,11 @@ void LocalBase::deduce_layout_fwd(const TensorLayout &src,
     auto errmsg_c = errmsg.c_str();
     MEGDNN_MARK_USED_VAR(errmsg_c);
-    megdnn_assert_contiguous(src);
+    //! in batch dim we don't need contiguous
+    TensorLayout src_contig = src;
+    src_contig.init_contiguous_stride();
+    src_contig.stride[0] = src.stride[0];
+    megdnn_assert_eq_layout(src_contig, src);
     megdnn_assert_contiguous(filter);
     megdnn_assert(src.ndim == 4_z, "%s", errmsg_c);
     megdnn_assert(filter.ndim == 6_z, "%s", errmsg_c);
@@ -67,6 +71,8 @@ void LocalBase::check_layout_fwd(const TensorLayout &src,
     megdnn_assert_eq_dtype(src, filter);
     megdnn_assert_eq_dtype(src, dst);
     deduce_layout_fwd(src, filter, dst_expected);
+    //! in batch dim we don't need contiguous
+    dst_expected.stride[0] = dst.stride[0];
     megdnn_assert_eq_layout(dst_expected, dst);
     megdnn_assert(src.dtype == filter.dtype && src.dtype == dst.dtype);
......
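The deduce_layout_fwd/check_layout_fwd hunks above relax the contiguity requirement only in the batch dimension: an expected layout is rebuilt as fully contiguous, its batch stride is then overwritten with the actual one, and the two layouts are compared. Below is a toy model of that check using a simplified Layout struct rather than MegDNN's TensorLayout; the names are illustrative only.

```cpp
// Simplified model of the "contiguous except in the batch dim" check.
#include <array>
#include <cassert>
#include <cstddef>

struct Layout {
    std::array<size_t, 4> shape;   // N, C, H, W
    std::array<size_t, 4> stride;  // strides in elements

    // Packed (row-major) strides derived from the shape.
    void init_contiguous_stride() {
        stride[3] = 1;
        for (int i = 2; i >= 0; --i)
            stride[i] = stride[i + 1] * shape[i + 1];
    }
    bool operator==(const Layout& rhs) const {
        return shape == rhs.shape && stride == rhs.stride;
    }
};

// Passes if `src` is packed in C/H/W; the batch stride may be anything.
bool batch_relaxed_contiguous(const Layout& src) {
    Layout expect = src;
    expect.init_contiguous_stride();
    expect.stride[0] = src.stride[0];  // don't constrain the batch stride
    return expect == src;
}

int main() {
    Layout a{{2, 8, 4, 4}, {}};
    a.init_contiguous_stride();
    a.stride[0] = 256;  // e.g. a per-group view of a wider tensor
    assert(batch_relaxed_contiguous(a));
}
```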
@@ -6,141 +6,121 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
#include "src/cuda/group_local/opr_impl.h" #include "src/cuda/group_local/opr_impl.h"
#include <memory>
#include "megdnn/opr_param_defs.h"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/local/local.cuh"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
#include "src/cuda/group_local/forward/kern.cuh" #include "src/cuda/group_local/forward/kern.cuh"
#include "src/cuda/local/opr_impl.h"
#include "src/cuda/local/local.cuh"
-namespace megdnn {
-namespace cuda {
+using namespace megdnn;
+using namespace cuda;
+
+namespace {
+std::unique_ptr<LocalForward> get_opr(Handle* handle,
+                                      param::Convolution param) {
+    auto&& opr = handle->create_operator<LocalForward>();
+    opr->param() = param;
+    return std::move(opr);
+}
+
+template <typename T>
+void incr_ptr(T*& dst, ptrdiff_t delta) {
+    dst = reinterpret_cast<T*>(reinterpret_cast<uintptr_t>(dst) + delta);
+}
+
+TensorLayout prepare_src_dst(const TensorLayout& input, size_t g) {
+    TensorLayout ret = input;
+    megdnn_assert(ret[1] % g == 0);
+    ret[1] = ret[1] / g;
+    ret.init_contiguous_stride();
+    //! change stride of batch
+    ret.stride[0] = input.stride[0];
+    return ret;
+}
+
+TensorLayout prepare_filter(const TensorLayout& filter) {
+    //! group, OH, OW, ICg, FH, FW, OCg -> OH, OW, ICg, FH, FW, OCg
+    return {{filter[1], filter[2], filter[3], filter[4], filter[5], filter[6]},
+            filter.dtype};
+}
+} // namespace
 void GroupLocalForwardImpl::exec(_megdnn_tensor_in src,
                                  _megdnn_tensor_in filter,
                                  _megdnn_tensor_out dst,
-                                 _megdnn_workspace workspace)
-{
+                                 _megdnn_workspace workspace) {
     megdnn_assert(src.layout.dtype == dtype::Float32(),
                   "cuda do not support fp16 group local operator");
     check_exec(src.layout, filter.layout, dst.layout, workspace.size);
+    auto handle = concrete_handle(this->handle());
     auto G = filter.layout[0];
-    auto N = src.layout.shape[0], IC = src.layout.shape[1]/G,
-         IH = src.layout.shape[2], IW = src.layout.shape[3],
-         OC = dst.layout.shape[1]/G,
+    auto IH = src.layout.shape[2], IW = src.layout.shape[3],
          OH = dst.layout.shape[2], OW = dst.layout.shape[3];
-    auto FH = filter.layout.shape[4], FW = filter.layout.shape[5];
-    auto PH = param().pad_h, PW = param().pad_w;
-    auto SH = param().stride_h, SW = param().stride_w;
-    const float *sptr = src.ptr<dt_float32>();
-    const float *fptr = filter.ptr<dt_float32>();
-    float *dptr = dst.ptr<dt_float32>();
-    float *wptr = workspace.ptr<dt_float32>();
-    auto handle = concrete_handle(this->handle());
-    auto stream = cuda_stream(this->handle());
-    auto cublas = cublas_handle(this->handle());
-    auto one = handle->one_device();
-    auto zero = handle->zero_device();
     if (prefer_inference_kernel(src.layout, filter.layout, dst.layout)) {
-        group_local::exec(sptr, fptr, dptr, wptr,
-                          N, IC, IH, IW,
-                          OC, OH, OW,
-                          FH, FW,
-                          G,
-                          PH, PW,
-                          SH, SW,
-                          stream);
-    } else if (local::can_forward_proxy_convnet(N, IC, IH, IW,
-                                                OC, OH, OW,
-                                                FH, FW,
-                                                G*IC*IH*IW, G*OC*OH*OW,
-                                                PH, PW,
-                                                SH, SW))
-    {
-        // use convnet
-        for (size_t g = 0; g < G; ++g) {
-            local::forward_proxy_convnet(sptr + g*IC*IH*IW,
-                                         fptr + g*OH*OW*IC*FH*FW*OC,
-                                         dptr + g*OC*OH*OW,
-                                         wptr,
-                                         N, IC, IH, IW,
-                                         OC, OH, OW,
-                                         FH, FW,
-                                         G*IC*IH*IW, G*OC*OH*OW,
-                                         PH, PW,
-                                         SH, SW,
-                                         cublas, stream, one, zero);
-        }
+        auto N = src.layout.shape[0], ICg = src.layout.shape[1] / G,
+             OCg = dst.layout.shape[1] / G;
+        auto FH = filter.layout.shape[4], FW = filter.layout.shape[5];
+        auto PH = param().pad_h, PW = param().pad_w;
+        auto SH = param().stride_h, SW = param().stride_w;
+        const float* sptr = src.ptr<dt_float32>();
+        const float* fptr = filter.ptr<dt_float32>();
+        float* dptr = dst.ptr<dt_float32>();
+        float* wptr = workspace.ptr<dt_float32>();
+        auto stream = cuda_stream(this->handle());
+
+        group_local::exec(sptr, fptr, dptr, wptr, N, ICg, IH, IW, OCg, OH, OW,
+                          FH, FW, G, PH, PW, SH, SW, stream);
     } else {
-        local::check_input(N, IC, IH, IW, OC, OH, OW, FH, FW,
-                           G*IC*IH*IW, G*OC*OH*OW,
-                           PH, PW,
-                           SH, SW,
-                           true);
-        // do not use convnet
+        auto&& opr = get_opr(handle, param());
+        TensorND src_g = {src.raw_ptr, prepare_src_dst(src.layout, G)};
+        TensorND dst_g = {dst.raw_ptr, prepare_src_dst(dst.layout, G)};
+        TensorND filter_g = {filter.raw_ptr, prepare_filter(filter.layout)};
         for (size_t g = 0; g < G; ++g) {
-            local::forward_proxy_weiming(sptr + g*IC*IH*IW,
-                                         fptr + g*OH*OW*IC*FH*FW*OC,
-                                         dptr + g*OC*OH*OW,
-                                         N, IC, IH, IW,
-                                         OC, OH, OW,
-                                         FH, FW,
-                                         G*IC*IH*IW, G*OC*OH*OW,
-                                         PH, PW,
-                                         SH, SW,
-                                         true, stream);
+            opr->exec(src_g, filter_g, dst_g, workspace);
+            incr_ptr(src_g.raw_ptr, src_g.layout.stride[1] *
+                                            src_g.layout.shape[1] *
+                                            src_g.layout.dtype.size());
+            incr_ptr(dst_g.raw_ptr, dst_g.layout.stride[1] *
+                                            dst_g.layout.shape[1] *
+                                            dst_g.layout.dtype.size());
+            incr_ptr(filter_g.raw_ptr, filter_g.layout.span().dist_byte());
         }
     }
 }
-GroupLocalForwardImpl::GroupLocalForwardImpl(Handle *handle):
-    GroupLocalForward(handle)
-{
-}
+GroupLocalForwardImpl::GroupLocalForwardImpl(Handle* handle)
+        : GroupLocalForward(handle) {}
-size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout &src,
-                                                     const TensorLayout &filter,
-                                                     const TensorLayout &dst)
-{
-    auto G = filter[0];
-    auto N = src.shape[0], IC = src.shape[1]/G,
-         IH = src.shape[2], IW = src.shape[3],
-         OC = dst.shape[1]/G,
-         OH = dst.shape[2], OW = dst.shape[3];
-    auto FH = filter.shape[4], FW = filter.shape[5];
-    auto PH = param().pad_h, PW = param().pad_w;
-    auto SH = param().stride_h, SW = param().stride_w;
+size_t GroupLocalForwardImpl::get_workspace_in_bytes(const TensorLayout& src,
+                                                     const TensorLayout& filter,
+                                                     const TensorLayout& dst) {
     if (prefer_inference_kernel(src, filter, dst)) {
         return 0;
-    } else if (local::can_forward_proxy_convnet(N, IC, IH, IW,
-                                                OC, OH, OW,
-                                                FH, FW,
-                                                G*IC*IH*IW, G*OC*OH*OW,
-                                                PH, PW,
-                                                SH, SW))
-    {
-        auto res = local::get_workspace_in_floats_forward_proxy_convnet(N,
-                IC, IH, IW,
-                OC, OH, OW,
-                FH, FW,
-                G*IC*IH*IW, G*OC*OH*OW,
-                PH, PW,
-                SH, SW) * sizeof(float);
-        return res;
     } else {
-        return 0;
+        auto G = filter[0];
+        TensorLayout src_g = prepare_src_dst(src, G);
+        TensorLayout dst_g = prepare_src_dst(dst, G);
+        TensorLayout filter_g = prepare_filter(filter);
+        return get_opr(handle(), param())
+                ->get_workspace_in_bytes(src_g, filter_g, dst_g);
     }
 }
-bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout &src,
-                                                    const TensorLayout &filter,
-                                                    const TensorLayout &dst)
-{
+bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout& src,
+                                                    const TensorLayout& filter,
+                                                    const TensorLayout& dst) {
     MEGDNN_MARK_USED_VAR(filter);
     MEGDNN_MARK_USED_VAR(dst);
     auto handle = concrete_handle(this->handle());
@@ -149,6 +129,4 @@ bool GroupLocalForwardImpl::prefer_inference_kernel(const TensorLayout &src,
            group_local::get_share_mem_in_bytes(IH, IW);
 }
 
-} // namespace cuda
-} // namespace megdnn
-
 // vim: syntax=cpp.doxygen
@@ -78,6 +78,8 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
     auto cublas = cublas_handle(this->handle());
     auto one = handle->one_device();
     auto zero = handle->zero_device();
+    size_t src_batch_strd = src.layout.stride[0];
+    size_t dst_batch_strd = dst.layout.stride[0];
     if (use_cuda_convnet(src.layout, filter.layout, dst.layout)) {
         local::forward_proxy_convnet(src.ptr<dt_float32>(),
                                      filter.ptr<dt_float32>(),
@@ -87,7 +89,7 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                                      IC, IH, IW,
                                      OC, OH, OW,
                                      FH, FW,
-                                     IC*IH*IW, OC*OH*OW,
+                                     src_batch_strd, dst_batch_strd,
                                      param().pad_h, param().pad_w,
                                      param().stride_h, param().stride_w,
                                      cublas, stream,
@@ -105,7 +107,7 @@ void LocalForwardImpl::exec(_megdnn_tensor_in src,
                                      IC, IH, IW,
                                      OC, OH, OW,
                                      FH, FW,
-                                     IC*IH*IW, OC*OH*OW,
+                                     src_batch_strd, dst_batch_strd,
                                      param().pad_h, param().pad_w,
                                      param().stride_h, param().stride_w,
                                      is_xcorr,
@@ -124,12 +126,14 @@ size_t LocalForwardImpl::get_workspace_in_bytes(const TensorLayout &src,
          FH = filter.shape[3], FW = filter.shape[4];
     auto PH = param().pad_h, PW = param().pad_w,
          SH = param().stride_h, SW = param().stride_w;
+    size_t src_batch_strd = src.stride[0];
+    size_t dst_batch_strd = dst.stride[0];
     if (use_cuda_convnet(src, filter, dst)) {
         res = local::get_workspace_in_floats_forward_proxy_convnet(N,
                 IC, IH, IW,
                 OC, OH, OW,
                 FH, FW,
-                IC*IH*IW, OC*OH*OW,
+                src_batch_strd, dst_batch_strd,
                 PH, PW,
                 SH, SW) * sizeof(dt_float32);
     } else {
......
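The hunks above pass the batch stride taken from the layout (src.layout.stride[0]) instead of recomputing it as IC*IH*IW, so a per-group view whose batches are spaced by the full grouped tensor still indexes correctly. A small arithmetic sketch of the difference, with made-up sizes, follows.

```cpp
// Sketch with made-up sizes: element offset of (n, c, h, w) for a per-group
// view of a packed [N, G*Cg, H, W] tensor.  The packed per-group product
// Cg*H*W is NOT the batch stride of the view; stride[0] from the layout is.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t G = 4, Cg = 8, H = 16, W = 16;
    const size_t batch_stride = G * Cg * H * W;  // from the layout: 8192
    const size_t packed_guess = Cg * H * W;      // old IC*IH*IW assumption: 2048

    const size_t n = 1, c = 2, h = 3, w = 4;
    size_t correct = n * batch_stride + c * H * W + h * W + w;
    size_t wrong   = n * packed_guess + c * H * W + h * W + w;
    std::printf("layout-stride offset: %zu, packed-assumption offset: %zu\n",
                correct, wrong);
}
```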