Commit 9f352b1c authored by Megvii Engine Team, committed by Xu Xinran

feat(megbrain/dnn): add indexing remap int32 for naive and cuda

GitOrigin-RevId: 5f66d51de4751d77fc05e2849388fc6dfbae4a53
Parent 5dbf218d
@@ -32,9 +32,11 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src,
     }
     megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c);
-    megdnn_assert(src.dtype == dtype::Float32());
+    megdnn_assert(dst.dtype == src.dtype);
+    megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(),
+                  "indexing remap only support float32/int32, got %s",
+                  src.dtype.name());
     megdnn_assert(map.dtype == dtype::Int32());
-    megdnn_assert(dst.dtype == dtype::Float32());
 }
 
 void IndexingRemapForward::deduce_layout(const TensorLayout &src,
@@ -36,13 +36,23 @@ void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
     for (size_t i = 0_z; i < dst.layout.ndim; ++i) {
         dshape.data[i] = dst.layout.shape[i];
     }
     // Invoke kernel
-    tensor_remap::forward(src.ptr<dt_float32>(),
-                          map.ptr<dt_int32>(),
-                          dst.ptr<dt_float32>(),
-                          src.layout.ndim, dst.layout.ndim,
-                          sstride, dstride, dshape,
-                          cuda_stream(handle()));
+#define cb(dt)                                                               \
+    if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                \
+        using ctype = DTypeTrait<dt>::ctype;                                 \
+        tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(), \
+                                     dst.ptr<ctype>(), src.layout.ndim,     \
+                                     dst.layout.ndim, sstride, dstride,     \
+                                     dshape, cuda_stream(handle()));        \
+        return;                                                              \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+#undef cb
+    megdnn_throw(
+            ssprintf("cuda indexing remap forward only support "
+                     "float32/int32 dtype, got %s",
+                     src.layout.dtype.name()));
 }
 
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
@@ -69,18 +79,27 @@ void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
     for (size_t i = 0_z; i < diff.layout.ndim; ++i) {
         dshape.data[i] = diff.layout.shape[i];
     }
     // Invoke kernel
-    tensor_remap::backward(diff.ptr<dt_float32>(),
-                           map.ptr<dt_int32>(),
-                           grad.ptr<dt_float32>(),
-                           grad.layout.ndim, diff.layout.ndim,
-                           sstride, dstride, sshape, dshape,
-                           param().is_non_overlapping,
-                           cuda_stream(handle()));
+#define cb(dt)                                                                \
+    if (diff.layout.dtype.enumv() == DTypeTrait<dt>::enumv) {                \
+        using ctype = DTypeTrait<dt>::ctype;                                  \
+        tensor_remap::backward<ctype>(                                        \
+                diff.ptr<ctype>(), map.ptr<dt_int32>(), grad.ptr<ctype>(),   \
+                grad.layout.ndim, diff.layout.ndim, sstride, dstride, sshape, \
+                dshape, param().is_non_overlapping, cuda_stream(handle()));  \
+        return;                                                               \
+    }
+    cb(dtype::Float32)
+    cb(dtype::Int32)
+    megdnn_throw(
+            ssprintf("cuda indexing remap backward only support "
+                     "float32/int32 dtype, got %s",
+                     diff.layout.dtype.name()));
 }
 
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
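
Note on the `cb` dispatch macro used in both exec functions above: each `cb(dtype::X)` expands to an if-block that compares the runtime dtype against `DTypeTrait<dtype::X>::enumv`, fixes `ctype` to the matching C type, launches the templated launcher, and returns. Hand-expanding the first invocation of the forward path gives roughly the following (illustrative expansion only, not literal preprocessor output):

if (src.layout.dtype.enumv() == DTypeTrait<dtype::Float32>::enumv) {
    using ctype = DTypeTrait<dtype::Float32>::ctype;  // dt_float32
    tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(),
                                 dst.ptr<ctype>(), src.layout.ndim,
                                 dst.layout.ndim, sstride, dstride, dshape,
                                 cuda_stream(handle()));
    return;
}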
@@ -6,28 +6,29 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
-#include "src/cuda/tensor_remap/tensor_remap.cuh"
 #include "src/cuda/query_blocksize.cuh"
+#include "src/cuda/tensor_remap/tensor_remap.cuh"
 
 namespace megdnn {
 namespace cuda {
-namespace tensor_remap {
+namespace {
 
-__global__ void forward_kernel(const float *src, const int *map, float *dst,
-                               uint32_t sdim, uint32_t ddim,
-                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-                               uint32_t total)
-{
+template <typename ctype>
+__global__ void forward_kernel(const ctype* src, const int* map, ctype* dst,
+                               uint32_t sdim, uint32_t ddim,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                               array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                               array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                               uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -41,34 +42,16 @@ __global__ void forward_kernel(const float *src, const int *map, float *dst,
     }
 }
 
-void forward(const float *src, const int *map, float *dst,
-             uint32_t sdim, uint32_t ddim,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-             cudaStream_t stream)
-{
-    uint32_t total = 1u;
-    for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
-    uint32_t threads = query_blocksize_for_kernel((void *)&forward_kernel);
-    uint32_t blocks = DIVUP(total, threads);
-    forward_kernel<<<blocks, threads, 0, stream>>>(src, map, dst,
-                                                   sdim, ddim,
-                                                   sstride, dstride, dshape,
-                                                   total);
-    after_kernel_launch();
-}
-
-__global__ void fill_zero_kernel(float *a, uint32_t dim,
-                                 array_wrapper<int, MEGDNN_MAX_NDIM> stride,
-                                 array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
-                                 uint32_t total)
-{
+template <typename ctype>
+__global__ void fill_zero_kernel(ctype* a, uint32_t dim,
+                                 array_wrapper<int, MEGDNN_MAX_NDIM> stride,
+                                 array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
+                                 uint32_t total) {
     uint32_t idx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (idx_cont < total) {
         uint32_t idx = 0u;
         for (uint32_t j = dim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t idx_cur = idx_cont % shape.data[i];
             idx_cont /= shape.data[i];
             idx += idx_cur * stride.data[i];
@@ -77,19 +60,19 @@ __global__ void fill_zero_kernel(float *a, uint32_t dim,
     }
 }
 
-__global__ void backward_kernel(const float *diff, const int *map, float *grad,
-                                uint32_t sdim, uint32_t ddim,
-                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
-                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-                                uint32_t total)
-{
+template <typename ctype>
+__global__ void backward_kernel(const ctype* diff, const int* map, ctype* grad,
+                                uint32_t sdim, uint32_t ddim,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+                                array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
+                                array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
+                                uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -103,20 +86,18 @@ __global__ void backward_kernel(const float *diff, const int *map, float *grad,
     }
 }
 
+template <typename ctype>
 __global__ void backward_kernel_non_overlapping(
-        const float *diff, const int *map, float *grad,
-        uint32_t sdim, uint32_t ddim,
-        array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
+        const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+        uint32_t ddim, array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
         array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
-        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
-        uint32_t total)
-{
+        array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, uint32_t total) {
     uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
     if (didx_cont < total) {
         uint32_t midx = didx_cont * sdim;
         uint32_t didx = 0u;
         for (uint32_t j = ddim; j > 0u; --j) {
-            uint32_t i = j-1u;
+            uint32_t i = j - 1u;
             uint32_t didx_cur = didx_cont % dshape.data[i];
             didx_cont /= dshape.data[i];
             didx += didx_cur * dstride.data[i];
@@ -130,55 +111,91 @@ __global__ void backward_kernel_non_overlapping(
     }
 }
 
-void backward(const float *diff, const int *map, float *grad,
-              uint32_t sdim, uint32_t ddim,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-              bool is_non_overlapping,
-              cudaStream_t stream)
-{
+} // anonymous namespace
+
+namespace tensor_remap {
+
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+             cudaStream_t stream) {
+    uint32_t total = 1u;
+    for (uint32_t i = 0u; i < ddim; ++i)
+        total *= dshape.data[i];
+    uint32_t threads =
+            query_blocksize_for_kernel((void*)&forward_kernel<ctype>);
+    uint32_t blocks = DIVUP(total, threads);
+    forward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+            src, map, dst, sdim, ddim, sstride, dstride, dshape, total);
+    after_kernel_launch();
+}
+
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream) {
     {
         // Fill grad with zeros.
         uint32_t total = 1u;
-        for (uint32_t i = 0u; i < sdim; ++i) total *= sshape.data[i];
-        uint32_t threads = query_blocksize_for_kernel((void *)&fill_zero_kernel);
+        for (uint32_t i = 0u; i < sdim; ++i)
+            total *= sshape.data[i];
+        uint32_t threads =
+                query_blocksize_for_kernel((void*)&fill_zero_kernel<ctype>);
         uint32_t blocks = DIVUP(total, threads);
-        fill_zero_kernel<<<blocks, threads, 0, stream>>>(
+        fill_zero_kernel<ctype><<<blocks, threads, 0, stream>>>(
                 grad, sdim, sstride, sshape, total);
         after_kernel_launch();
     }
     {
         // Update grad.
         uint32_t total = 1u;
-        for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
+        for (uint32_t i = 0u; i < ddim; ++i)
+            total *= dshape.data[i];
         if (is_non_overlapping) {
             uint32_t threads = query_blocksize_for_kernel(
-                    (void *)&backward_kernel_non_overlapping);
+                    (void*)&backward_kernel_non_overlapping<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel_non_overlapping<<<blocks, threads, 0, stream>>>(
-                    diff, map, grad,
-                    sdim, ddim,
-                    sstride, dstride, dshape,
-                    total);
+            backward_kernel_non_overlapping<ctype>
+                    <<<blocks, threads, 0, stream>>>(diff, map, grad, sdim,
                                                     ddim, sstride, dstride,
                                                     dshape, total);
         } else {
-            uint32_t threads = query_blocksize_for_kernel(
-                    (void *)&backward_kernel);
+            uint32_t threads =
+                    query_blocksize_for_kernel((void*)&backward_kernel<ctype>);
             uint32_t blocks = DIVUP(total, threads);
-            backward_kernel<<<blocks, threads, 0, stream>>>(diff, map, grad,
-                                                            sdim, ddim,
-                                                            sstride, dstride, dshape,
-                                                            total);
+            backward_kernel<ctype><<<blocks, threads, 0, stream>>>(
+                    diff, map, grad, sdim, ddim, sstride, dstride, dshape,
+                    total);
         }
         after_kernel_launch();
     }
 }
 
-} // namespace tensor_remap
-} // namespace cuda
-} // namespace megdnn
-
-// vim: syntax=cpp.doxygen
+#define INST(T)                                                                \
+    template void forward<T>(                                                  \
+            const T* src, const int* map, T* dst, uint32_t sdim,               \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            cudaStream_t stream);                                              \
+    template void backward<T>(                                                 \
+            const T* diff, const int* map, T* grad, uint32_t sdim,             \
+            uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
+            const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,                \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,            \
+            const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,            \
+            bool is_non_overlapping, cudaStream_t stream);
+
+INST(dt_float32)
+INST(dt_int32)
+
+#undef INST
+
+} // namespace tensor_remap
+} // namespace cuda
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
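
The `INST` macro at the end of the .cu file performs explicit template instantiation: the templated `forward`/`backward` launchers are defined in a CUDA translation unit, so callers in opr_impl.cpp only ever see their declarations and would otherwise fail to link. A minimal self-contained sketch of the same pattern, using hypothetical names that are not part of this commit:

// scale.cu -- templated kernel plus launcher, instantiated explicitly
#include <cuda_runtime.h>

template <typename T>
__global__ void scale_kernel(T* data, T factor, unsigned n) {
    unsigned i = threadIdx.x + blockIdx.x * blockDim.x;
    if (i < n)
        data[i] *= factor;
}

template <typename T>
void scale(T* data, T factor, unsigned n, cudaStream_t stream) {
    unsigned threads = 256;
    unsigned blocks = (n + threads - 1) / threads;
    scale_kernel<T><<<blocks, threads, 0, stream>>>(data, factor, n);
}

// Explicit instantiations: a .cpp caller that only sees the declaration of
// scale<T>() links against exactly these two symbols.
template void scale<float>(float*, float, unsigned, cudaStream_t);
template void scale<int>(int*, int, unsigned, cudaStream_t);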
@@ -17,25 +17,23 @@ namespace megdnn {
 namespace cuda {
 namespace tensor_remap {
 
-void forward(const float *src, const int *map, float *dst,
-             uint32_t sdim, uint32_t ddim,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-             const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
+template <typename ctype>
+void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
+             uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+             const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+             const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
              cudaStream_t stream);
 
-void backward(const float *diff, const int *map, float *grad,
-              uint32_t sdim, uint32_t ddim,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
-              const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
-              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
-              bool is_non_overlapping,
-              cudaStream_t stream);
+template <typename ctype>
+void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
+              uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
+              const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
+              const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
+              bool is_non_overlapping, cudaStream_t stream);
 
 } // namespace tensor_remap
 } // namespace cuda
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
@@ -6,75 +6,107 @@
  *
  * Unless required by applicable law or agreed to in writing,
  * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
  */
 #include "src/naive/tensor_remap/opr_impl.h"
 #include "src/common/utils.h"
 #include "src/naive/handle.h"
 
-namespace megdnn {
-namespace naive {
+using namespace megdnn;
+using namespace naive;
+
+namespace {
+
+template <typename ctype>
+void forward(const TensorND& src, const TensorND& map, const TensorND& dst) {
+    auto&& sshape = src.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = dst.layout;
+    // Last element is zero to facilitate maddr calculation.
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
+        dst.ptr<ctype>()[daddr] = src.ptr<ctype>()[saddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+
+template <typename ctype>
+void backward(const TensorND& diff, const TensorND& map, const TensorND& grad) {
+    auto&& sshape = grad.layout;
+    auto&& mshape = map.layout;
+    auto&& dshape = diff.layout;
+    std::vector<size_t> sidx(sshape.ndim, 0_z);
+    {
+        // Set grad to zero.
+        do {
+            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+            grad.ptr<ctype>()[saddr] = 0.0f;
+        } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
+    }
+    std::vector<size_t> didx(dshape.ndim + 1, 0_z);
+    do {
+        auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
+        std::vector<size_t> sidx(sshape.ndim);
+        for (size_t i = 0_z; i < sshape.ndim; ++i) {
+            sidx[i] = map.ptr<dt_int32>()[maddr + i];
+        }
+        auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
+        auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
+        grad.ptr<ctype>()[saddr] += diff.ptr<ctype>()[daddr];
+    } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
+}
+
+} // anonymous namespace
 
 void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
                                     _megdnn_tensor_in map,
                                     _megdnn_tensor_out dst,
-                                    _megdnn_workspace workspace)
-{
+                                    _megdnn_workspace workspace) {
     check_exec(src.layout, map.layout, dst.layout, workspace.size);
-    auto kern = [=]() {
-        auto &&sshape = src.layout;
-        auto &&mshape = map.layout;
-        auto &&dshape = dst.layout;
-        // Last element is zero to facilitate maddr calculation.
-        std::vector<size_t> didx(dshape.ndim+1, 0_z);
-        do {
-            auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr+i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
-            dst.ptr<dt_float32>()[daddr] = src.ptr<dt_float32>()[saddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (src.layout.dtype.enumv()) {
+#define cb(dt)                                                  \
+    case DTypeTrait<dt>::enumv:                                 \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                           \
+                forward<DTypeTrait<dt>::ctype>(src, map, dst)); \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(
+                    ssprintf("unsupported dtype %s in indexing "
+                             "remap forward naive\n",
+                             src.layout.dtype.name()));
+    }
 }
 
 void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
                                      _megdnn_tensor_in map,
                                      _megdnn_tensor_out grad,
-                                     _megdnn_workspace workspace)
-{
+                                     _megdnn_workspace workspace) {
     check_exec(diff.layout, map.layout, grad.layout, workspace.size);
-    auto kern = [=]() {
-        auto &&sshape = grad.layout;
-        auto &&mshape = map.layout;
-        auto &&dshape = diff.layout;
-        std::vector<size_t> sidx(sshape.ndim, 0_z);
-        {
-            // Set grad to zero.
-            do {
-                auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-                grad.ptr<dt_float32>()[saddr] = 0.0f;
-            } while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
-        }
-        std::vector<size_t> didx(dshape.ndim+1, 0_z);
-        do {
-            auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
-            std::vector<size_t> sidx(sshape.ndim);
-            for (size_t i = 0_z; i < sshape.ndim; ++i) {
-                sidx[i] = map.ptr<dt_int32>()[maddr+i];
-            }
-            auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
-            auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
-            grad.ptr<dt_float32>()[saddr] += diff.ptr<dt_float32>()[daddr];
-        } while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
-    };
-    MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
+    switch (diff.layout.dtype.enumv()) {
+#define cb(dt)                                                    \
+    case DTypeTrait<dt>::enumv:                                   \
+        MEGDNN_DISPATCH_CPU_KERN_OPR(                             \
+                backward<DTypeTrait<dt>::ctype>(diff, map, grad)); \
+        return;
+        cb(dtype::Float32)
+        cb(dtype::Int32)
+#undef cb
+        default:
+            megdnn_throw(ssprintf(
+                    "unsupported dtype %s in indexing remap backward naive\n",
+                    diff.layout.dtype.name()));
+    }
 }
 
-} // namespace naive
-} // namespace megdnn
-
 // vim: syntax=cpp.doxygen
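
The `switch`/`cb` pattern above maps a runtime dtype to a compile-time element type before calling the templated helper, so each supported dtype gets its own compiled copy of `forward`/`backward`. Stripped of the megdnn type system, the underlying idea looks like this (simplified, hypothetical standalone sketch):

#include <cstddef>
#include <stdexcept>

enum class DTypeEnum { Float32, Int32 };

template <typename T>
void copy_impl(const void* src, void* dst, size_t n) {
    auto s = static_cast<const T*>(src);
    auto d = static_cast<T*>(dst);
    for (size_t i = 0; i < n; ++i)
        d[i] = s[i];
}

// Each case fixes the element type, so the helper is instantiated per dtype.
void copy_dispatch(DTypeEnum dt, const void* src, void* dst, size_t n) {
    switch (dt) {
        case DTypeEnum::Float32:
            copy_impl<float>(src, dst, n);
            return;
        case DTypeEnum::Int32:
            copy_impl<int>(src, dst, n);
            return;
        default:
            throw std::runtime_error("unsupported dtype");
    }
}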
@@ -16,39 +16,42 @@
 namespace megdnn {
 namespace test {
 
-TEST_F(CUDA, TENSOR_REMAP_FORWARD)
-{
+TEST_F(CUDA, TENSOR_REMAP_FORWARD) {
     Checker<IndexingRemapForward> checker(handle_cuda());
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
     checker.set_dtype(1, dtype::Int32());
-    TensorShape src{11, 13, 17},
-                map{3, 5, 7, 3},
-                dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({src, map, {}});
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({src, map, {}});
+        }
     }
 }
 
-TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
-{
+TEST_F(CUDA, TENSOR_REMAP_BACKWARD) {
     Checker<IndexingRemapBackward> checker(handle_cuda());
     checker.set_dtype(1, dtype::Int32());
-    TensorShape src{11, 13, 17},
-                map{3, 5, 7, 3},
-                dst{3, 5, 7};
-    using namespace tensor_remap;
-    {
-        MapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
-    }
-    {
-        NonoverlappingMapRNG rng(src);
-        checker.set_rng(1, &rng).execs({dst, map, src});
+    TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
+    checker.set_dtype(1, dtype::Int32());
+    for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
+        checker.set_dtype(0, dt);
+        checker.set_dtype(2, dt);
+        using namespace tensor_remap;
+        {
+            MapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
+        {
+            NonoverlappingMapRNG rng(src);
+            checker.set_rng(1, &rng).execs({dst, map, src});
+        }
     }
 }
@@ -56,5 +59,3 @@ TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
 } // namespace megdnn
 
 // vim: syntax=cpp.doxygen
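
Both tests now sweep dtype::Float32 and dtype::Int32 through the same Checker. Since TEST_F(CUDA, ...) yields the gtest names CUDA.TENSOR_REMAP_FORWARD and CUDA.TENSOR_REMAP_BACKWARD, they can be run in isolation with a standard gtest filter; the test binary name below is an assumption, not something stated in this diff:

./megdnn_test --gtest_filter='CUDA.TENSOR_REMAP_FORWARD:CUDA.TENSOR_REMAP_BACKWARD'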