Commit 9f352b1c authored by Megvii Engine Team, committed by Xu Xinran

feat(megbrain/dnn): add indexing remap int32 for naive and cuda

GitOrigin-RevId: 5f66d51de4751d77fc05e2849388fc6dfbae4a53
Parent 5dbf218d
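This commit extends indexing remap from float32-only to float32/int32 by templating the CUDA and naive kernels on the element type and dispatching on the runtime dtype at each exec entry point. A minimal standalone sketch of that dispatch idiom, with hypothetical names (the actual code below uses MegDNN's DTypeTrait and a local cb macro):

// Sketch of the dtype-dispatch idiom used throughout this commit; the
// run_typed/dispatch helpers are illustrative, not MegDNN symbols.
#include <cstddef>
#include <cstdint>
#include <stdexcept>

enum class DTypeEnum { Float32, Int32 };

template <typename ctype>
void run_typed(const void* src, void* dst, size_t n) {
    auto s = static_cast<const ctype*>(src);
    auto d = static_cast<ctype*>(dst);
    for (size_t i = 0; i < n; ++i) d[i] = s[i];
}

void dispatch(DTypeEnum dtype, const void* src, void* dst, size_t n) {
    switch (dtype) {
        case DTypeEnum::Float32: return run_typed<float>(src, dst, n);
        case DTypeEnum::Int32:   return run_typed<int32_t>(src, dst, n);
        default: throw std::runtime_error("unsupported dtype");
    }
}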
......@@ -32,9 +32,11 @@ void IndexingRemapBase::check_layout_fwd(const TensorLayout &src,
}
megdnn_assert(map.shape[dst.ndim] == src.ndim, "%s", errmsg_c);
megdnn_assert(src.dtype == dtype::Float32());
megdnn_assert(dst.dtype == src.dtype);
megdnn_assert(src.dtype == dtype::Float32() || src.dtype == dtype::Int32(),
"indexing remap only support float32/int32, got %s",
src.dtype.name());
megdnn_assert(map.dtype == dtype::Int32());
megdnn_assert(dst.dtype == dtype::Float32());
}
void IndexingRemapForward::deduce_layout(const TensorLayout &src,
......
......@@ -36,13 +36,23 @@ void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
for (size_t i = 0_z; i < dst.layout.ndim; ++i) {
dshape.data[i] = dst.layout.shape[i];
}
// Invoke kernel
tensor_remap::forward(src.ptr<dt_float32>(),
map.ptr<dt_int32>(),
dst.ptr<dt_float32>(),
src.layout.ndim, dst.layout.ndim,
sstride, dstride, dshape,
cuda_stream(handle()));
// Invoke kernel
#define cb(dt) \
if (src.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
using ctype = DTypeTrait<dt>::ctype; \
tensor_remap::forward<ctype>(src.ptr<ctype>(), map.ptr<dt_int32>(), \
dst.ptr<ctype>(), src.layout.ndim, \
dst.layout.ndim, sstride, dstride, \
dshape, cuda_stream(handle())); \
return; \
}
cb(dtype::Float32)
cb(dtype::Int32)
#undef cb
megdnn_throw(
ssprintf("cuda indexing remap forward only support "
"float32/int32 dtype, got %s",
src.layout.dtype.name()));
}
void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
......@@ -69,18 +79,27 @@ void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
for (size_t i = 0_z; i < diff.layout.ndim; ++i) {
dshape.data[i] = diff.layout.shape[i];
}
// Invoke kernel
tensor_remap::backward(diff.ptr<dt_float32>(),
map.ptr<dt_int32>(),
grad.ptr<dt_float32>(),
grad.layout.ndim, diff.layout.ndim,
sstride, dstride, sshape, dshape,
param().is_non_overlapping,
cuda_stream(handle()));
// Invoke kernel
#define cb(dt) \
if (diff.layout.dtype.enumv() == DTypeTrait<dt>::enumv) { \
using ctype = DTypeTrait<dt>::ctype; \
tensor_remap::backward<ctype>( \
diff.ptr<ctype>(), map.ptr<dt_int32>(), grad.ptr<ctype>(), \
grad.layout.ndim, diff.layout.ndim, sstride, dstride, sshape, \
dshape, param().is_non_overlapping, cuda_stream(handle())); \
return; \
}
cb(dtype::Float32)
cb(dtype::Int32)
#undef cb
megdnn_throw(
ssprintf("cuda indexing remap backward only support "
"float32/int32 dtype, got %s",
diff.layout.dtype.name()));
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,28 +6,29 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/cuda/tensor_remap/tensor_remap.cuh"
#include "src/cuda/query_blocksize.cuh"
#include "src/cuda/tensor_remap/tensor_remap.cuh"
namespace megdnn {
namespace cuda {
namespace tensor_remap {
namespace {
__global__ void forward_kernel(const float *src, const int *map, float *dst,
uint32_t sdim, uint32_t ddim,
array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
uint32_t total)
{
template <typename ctype>
__global__ void forward_kernel(const ctype* src, const int* map, ctype* dst,
uint32_t sdim, uint32_t ddim,
array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
uint32_t total) {
uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
if (didx_cont < total) {
uint32_t midx = didx_cont * sdim;
uint32_t didx = 0u;
for (uint32_t j = ddim; j > 0u; --j) {
uint32_t i = j-1u;
uint32_t i = j - 1u;
uint32_t didx_cur = didx_cont % dshape.data[i];
didx_cont /= dshape.data[i];
didx += didx_cur * dstride.data[i];
......@@ -41,34 +42,16 @@ __global__ void forward_kernel(const float *src, const int *map, float *dst,
}
}
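Inside forward_kernel, the contiguous destination index didx_cont is peeled apart dimension by dimension (innermost first) into a strided offset, while the same contiguous index times sdim locates the coordinate tuple in map. A host-side sketch of that decomposition, assuming a contiguous row-major enumeration of the output (the helper name is illustrative):

// Illustrative host-side version of the index decomposition performed by
// forward_kernel: turn a contiguous element index into a strided offset.
#include <cstdint>
#include <vector>

uint32_t contiguous_to_strided(uint32_t idx_cont,
                               const std::vector<uint32_t>& shape,
                               const std::vector<int>& stride) {
    uint32_t offset = 0u;
    // Walk dimensions from innermost to outermost, exactly as the kernel does.
    for (uint32_t j = static_cast<uint32_t>(shape.size()); j > 0u; --j) {
        uint32_t i = j - 1u;
        uint32_t cur = idx_cont % shape[i];  // coordinate along dim i
        idx_cont /= shape[i];
        offset += cur * stride[i];
    }
    return offset;
}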
void forward(const float *src, const int *map, float *dst,
uint32_t sdim, uint32_t ddim,
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
cudaStream_t stream)
{
uint32_t total = 1u;
for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
uint32_t threads = query_blocksize_for_kernel((void *)&forward_kernel);
uint32_t blocks = DIVUP(total, threads);
forward_kernel<<<blocks, threads, 0, stream>>>(src, map, dst,
sdim, ddim,
sstride, dstride, dshape,
total);
after_kernel_launch();
}
__global__ void fill_zero_kernel(float *a, uint32_t dim,
array_wrapper<int, MEGDNN_MAX_NDIM> stride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
uint32_t total)
{
template <typename ctype>
__global__ void fill_zero_kernel(ctype* a, uint32_t dim,
array_wrapper<int, MEGDNN_MAX_NDIM> stride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> shape,
uint32_t total) {
uint32_t idx_cont = threadIdx.x + blockIdx.x * blockDim.x;
if (idx_cont < total) {
uint32_t idx = 0u;
for (uint32_t j = dim; j > 0u; --j) {
uint32_t i = j-1u;
uint32_t i = j - 1u;
uint32_t idx_cur = idx_cont % shape.data[i];
idx_cont /= shape.data[i];
idx += idx_cur * stride.data[i];
......@@ -77,19 +60,19 @@ __global__ void fill_zero_kernel(float *a, uint32_t dim,
}
}
__global__ void backward_kernel(const float *diff, const int *map, float *grad,
uint32_t sdim, uint32_t ddim,
array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
uint32_t total)
{
template <typename ctype>
__global__ void backward_kernel(const ctype* diff, const int* map, ctype* grad,
uint32_t sdim, uint32_t ddim,
array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
uint32_t total) {
uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
if (didx_cont < total) {
uint32_t midx = didx_cont * sdim;
uint32_t didx = 0u;
for (uint32_t j = ddim; j > 0u; --j) {
uint32_t i = j-1u;
uint32_t i = j - 1u;
uint32_t didx_cur = didx_cont % dshape.data[i];
didx_cont /= dshape.data[i];
didx += didx_cur * dstride.data[i];
......@@ -103,20 +86,18 @@ __global__ void backward_kernel(const float *diff, const int *map, float *grad,
}
}
template <typename ctype>
__global__ void backward_kernel_non_overlapping(
const float *diff, const int *map, float *grad,
uint32_t sdim, uint32_t ddim,
array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
uint32_t ddim, array_wrapper<int, MEGDNN_MAX_NDIM> sstride,
array_wrapper<int, MEGDNN_MAX_NDIM> dstride,
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape,
uint32_t total)
{
array_wrapper<uint32_t, MEGDNN_MAX_NDIM> dshape, uint32_t total) {
uint32_t didx_cont = threadIdx.x + blockIdx.x * blockDim.x;
if (didx_cont < total) {
uint32_t midx = didx_cont * sdim;
uint32_t didx = 0u;
for (uint32_t j = ddim; j > 0u; --j) {
uint32_t i = j-1u;
uint32_t i = j - 1u;
uint32_t didx_cur = didx_cont % dshape.data[i];
didx_cont /= dshape.data[i];
didx += didx_cur * dstride.data[i];
......@@ -130,55 +111,91 @@ __global__ void backward_kernel_non_overlapping(
}
}
void backward(const float *diff, const int *map, float *grad,
uint32_t sdim, uint32_t ddim,
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
bool is_non_overlapping,
cudaStream_t stream)
{
} // anonymous namespace
namespace tensor_remap {
template <typename ctype>
void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
cudaStream_t stream) {
uint32_t total = 1u;
for (uint32_t i = 0u; i < ddim; ++i)
total *= dshape.data[i];
uint32_t threads =
query_blocksize_for_kernel((void*)&forward_kernel<ctype>);
uint32_t blocks = DIVUP(total, threads);
forward_kernel<ctype><<<blocks, threads, 0, stream>>>(
src, map, dst, sdim, ddim, sstride, dstride, dshape, total);
after_kernel_launch();
}
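The launch configuration is the usual 1-D grid: one thread per destination element, with the block size taken from the occupancy query and the grid size rounded up. A hedged sketch of the same computation, with DIVUP written out explicitly (query_blocksize_for_kernel is only mimicked by the block_size parameter here):

// Illustrative grid-size computation: one thread per output element.
#include <cstdint>

struct LaunchConfig {
    uint32_t blocks;
    uint32_t threads;
};

LaunchConfig make_1d_launch(uint32_t total_elems, uint32_t block_size) {
    LaunchConfig cfg;
    cfg.threads = block_size;                                   // e.g. from an occupancy query
    cfg.blocks = (total_elems + block_size - 1) / block_size;   // DIVUP
    return cfg;
}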
template <typename ctype>
void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
bool is_non_overlapping, cudaStream_t stream) {
{
// Fill grad with zeros.
uint32_t total = 1u;
for (uint32_t i = 0u; i < sdim; ++i) total *= sshape.data[i];
uint32_t threads = query_blocksize_for_kernel((void *)&fill_zero_kernel);
for (uint32_t i = 0u; i < sdim; ++i)
total *= sshape.data[i];
uint32_t threads =
query_blocksize_for_kernel((void*)&fill_zero_kernel<ctype>);
uint32_t blocks = DIVUP(total, threads);
fill_zero_kernel<<<blocks, threads, 0, stream>>>(
fill_zero_kernel<ctype><<<blocks, threads, 0, stream>>>(
grad, sdim, sstride, sshape, total);
after_kernel_launch();
}
{
// Update grad.
uint32_t total = 1u;
for (uint32_t i = 0u; i < ddim; ++i) total *= dshape.data[i];
for (uint32_t i = 0u; i < ddim; ++i)
total *= dshape.data[i];
if (is_non_overlapping) {
uint32_t threads = query_blocksize_for_kernel(
(void *)&backward_kernel_non_overlapping);
(void*)&backward_kernel_non_overlapping<ctype>);
uint32_t blocks = DIVUP(total, threads);
backward_kernel_non_overlapping<<<blocks, threads, 0, stream>>>(
diff, map, grad,
sdim, ddim,
sstride, dstride, dshape,
total);
backward_kernel_non_overlapping<ctype>
<<<blocks, threads, 0, stream>>>(diff, map, grad, sdim,
ddim, sstride, dstride,
dshape, total);
} else {
uint32_t threads = query_blocksize_for_kernel(
(void *)&backward_kernel);
uint32_t threads =
query_blocksize_for_kernel((void*)&backward_kernel<ctype>);
uint32_t blocks = DIVUP(total, threads);
backward_kernel<<<blocks, threads, 0, stream>>>(diff, map, grad,
sdim, ddim,
sstride, dstride, dshape,
backward_kernel<ctype><<<blocks, threads, 0, stream>>>(
diff, map, grad, sdim, ddim, sstride, dstride, dshape,
total);
}
after_kernel_launch();
}
}
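On the host side, backward is two launches: fill_zero_kernel first clears grad over sshape, then either the overlapping or non-overlapping kernel scatters diff into grad through map. A hedged CPU sketch of the same two-phase flow, shown for float and contiguous layouts for brevity (backward_cpu and its parameters are illustrative, not the MegDNN API):

// Illustrative CPU version of the two-phase backward: zero grad, then
// scatter-add each diff element to the source coordinate stored in map.
#include <cstdint>
#include <cstring>
#include <vector>

void backward_cpu(const float* diff, const int* map, float* grad,
                  size_t grad_elems, size_t dst_elems, uint32_t sdim,
                  const std::vector<uint32_t>& sstride_contig) {
    // Phase 1: clear grad (all-zero bytes is 0.0f for IEEE float).
    std::memset(grad, 0, grad_elems * sizeof(float));
    // Phase 2: each destination element accumulates into the source location
    // named by its coordinate tuple in map.
    for (size_t d = 0; d < dst_elems; ++d) {
        const int* coord = map + d * sdim;
        size_t saddr = 0;
        for (uint32_t i = 0; i < sdim; ++i)
            saddr += static_cast<size_t>(coord[i]) * sstride_contig[i];
        grad[saddr] += diff[d];
    }
}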
} // namespace tensor_remap
} // namespace cuda
} // namespace megdnn
#define INST(T) \
template void forward<T>( \
const T* src, const int* map, T* dst, uint32_t sdim, \
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, \
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, \
cudaStream_t stream); \
template void backward<T>( \
const T* diff, const int* map, T* grad, uint32_t sdim, \
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride, \
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride, \
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape, \
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape, \
bool is_non_overlapping, cudaStream_t stream);
INST(dt_float32)
INST(dt_int32)
// vim: syntax=cpp.doxygen
#undef INST
} // namespace tensor_remap
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
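Because the templated forward/backward are defined in this .cu file but called from opr_impl.cpp, the INST macro above explicitly instantiates them for dt_float32 and dt_int32; without the explicit instantiations the linker would not find those symbols. A minimal illustration of the idiom, with hypothetical file and function names:

// lib.h — declaration only; the template body lives in the .cu/.cpp file.
template <typename T>
void copy_n(const T* src, T* dst, unsigned n);

// lib.cpp — definition plus explicit instantiations for the types callers use.
template <typename T>
void copy_n(const T* src, T* dst, unsigned n) {
    for (unsigned i = 0; i < n; ++i) dst[i] = src[i];
}
template void copy_n<float>(const float*, float*, unsigned);
template void copy_n<int>(const int*, int*, unsigned);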
......@@ -17,25 +17,23 @@ namespace megdnn {
namespace cuda {
namespace tensor_remap {
void forward(const float *src, const int *map, float *dst,
uint32_t sdim, uint32_t ddim,
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
cudaStream_t stream);
template <typename ctype>
void forward(const ctype* src, const int* map, ctype* dst, uint32_t sdim,
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
cudaStream_t stream);
void backward(const float *diff, const int *map, float *grad,
uint32_t sdim, uint32_t ddim,
const array_wrapper<int, MEGDNN_MAX_NDIM> &sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM> &dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &sshape,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM> &dshape,
bool is_non_overlapping,
cudaStream_t stream);
template <typename ctype>
void backward(const ctype* diff, const int* map, ctype* grad, uint32_t sdim,
uint32_t ddim, const array_wrapper<int, MEGDNN_MAX_NDIM>& sstride,
const array_wrapper<int, MEGDNN_MAX_NDIM>& dstride,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& sshape,
const array_wrapper<uint32_t, MEGDNN_MAX_NDIM>& dshape,
bool is_non_overlapping, cudaStream_t stream);
} // namespace tensor_remap
} // namespace tensor_remap
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,75 +6,107 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/tensor_remap/opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
namespace megdnn {
namespace naive {
using namespace megdnn;
using namespace naive;
namespace {
template <typename ctype>
void forward(const TensorND& src, const TensorND& map, const TensorND& dst) {
auto&& sshape = src.layout;
auto&& mshape = map.layout;
auto&& dshape = dst.layout;
// Last element is zero to facilitate maddr calculation.
std::vector<size_t> didx(dshape.ndim + 1, 0_z);
do {
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
std::vector<size_t> sidx(sshape.ndim);
for (size_t i = 0_z; i < sshape.ndim; ++i) {
sidx[i] = map.ptr<dt_int32>()[maddr + i];
}
auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
dst.ptr<ctype>()[daddr] = src.ptr<ctype>()[saddr];
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
}
template <typename ctype>
void backward(const TensorND& diff, const TensorND& map, const TensorND& grad) {
auto&& sshape = grad.layout;
auto&& mshape = map.layout;
auto&& dshape = diff.layout;
std::vector<size_t> sidx(sshape.ndim, 0_z);
{
// Set grad to zero.
do {
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
grad.ptr<ctype>()[saddr] = 0.0f;
} while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
}
std::vector<size_t> didx(dshape.ndim + 1, 0_z);
do {
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
std::vector<size_t> sidx(sshape.ndim);
for (size_t i = 0_z; i < sshape.ndim; ++i) {
sidx[i] = map.ptr<dt_int32>()[maddr + i];
}
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
grad.ptr<ctype>()[saddr] += diff.ptr<ctype>()[daddr];
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
}
} // anonymous namespace
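The naive forward above walks every destination coordinate, reads its source coordinate tuple from the trailing axis of map, and copies the element; backward does the reverse with accumulation. For intuition, a tiny self-contained example of the forward semantics with purely illustrative values:

// Illustration of indexing-remap forward semantics with 1-D src:
//   src = {10, 20, 30, 40}        shape {4}
//   map = {{3}, {0}, {2}}         shape {3, 1} (last axis = src ndim)
//   dst[i] = src[map[i][0]]  ->   dst = {40, 10, 30}
#include <cassert>
#include <vector>

int main() {
    std::vector<int> src{10, 20, 30, 40};
    std::vector<int> map{3, 0, 2};  // flattened {3, 1} map
    std::vector<int> dst(3);
    for (size_t i = 0; i < dst.size(); ++i)
        dst[i] = src[map[i]];
    assert((dst == std::vector<int>{40, 10, 30}));
    return 0;
}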
void IndexingRemapForwardImpl::exec(_megdnn_tensor_in src,
_megdnn_tensor_in map,
_megdnn_tensor_out dst,
_megdnn_workspace workspace)
{
_megdnn_tensor_in map,
_megdnn_tensor_out dst,
_megdnn_workspace workspace) {
check_exec(src.layout, map.layout, dst.layout, workspace.size);
auto kern = [=]() {
auto &&sshape = src.layout;
auto &&mshape = map.layout;
auto &&dshape = dst.layout;
// Last element is zero to facilitate maddr calculation.
std::vector<size_t> didx(dshape.ndim+1, 0_z);
do {
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
std::vector<size_t> sidx(sshape.ndim);
for (size_t i = 0_z; i < sshape.ndim; ++i) {
sidx[i] = map.ptr<dt_int32>()[maddr+i];
}
auto saddr = get_linear_addr_noncont(sidx.data(), src.layout);
auto daddr = get_linear_addr_noncont(didx.data(), dst.layout);
dst.ptr<dt_float32>()[daddr] = src.ptr<dt_float32>()[saddr];
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
};
MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
switch (src.layout.dtype.enumv()) {
#define cb(dt) \
case DTypeTrait<dt>::enumv: \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
forward<DTypeTrait<dt>::ctype>(src, map, dst)); \
return;
cb(dtype::Float32)
cb(dtype::Int32)
#undef cb
default:
megdnn_throw(
ssprintf("unsupported dtype %s in indexing "
"remap forward naive\n",
src.layout.dtype.name()));
}
}
void IndexingRemapBackwardImpl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in map,
_megdnn_tensor_out grad,
_megdnn_workspace workspace)
{
_megdnn_tensor_in map,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(diff.layout, map.layout, grad.layout, workspace.size);
auto kern = [=]() {
auto &&sshape = grad.layout;
auto &&mshape = map.layout;
auto &&dshape = diff.layout;
std::vector<size_t> sidx(sshape.ndim, 0_z);
{
// Set grad to zero.
do {
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
grad.ptr<dt_float32>()[saddr] = 0.0f;
} while (get_next_addr(sidx.data(), sshape.shape, sshape.ndim));
}
std::vector<size_t> didx(dshape.ndim+1, 0_z);
do {
auto maddr = get_linear_addr(didx.data(), mshape.shape, mshape.ndim);
std::vector<size_t> sidx(sshape.ndim);
for (size_t i = 0_z; i < sshape.ndim; ++i) {
sidx[i] = map.ptr<dt_int32>()[maddr+i];
}
auto saddr = get_linear_addr_noncont(sidx.data(), grad.layout);
auto daddr = get_linear_addr_noncont(didx.data(), diff.layout);
grad.ptr<dt_float32>()[saddr] += diff.ptr<dt_float32>()[daddr];
} while (get_next_addr(didx.data(), dshape.shape, dshape.ndim));
};
MEGDNN_DISPATCH_CPU_KERN_OPR(kern());
switch (diff.layout.dtype.enumv()) {
#define cb(dt) \
case DTypeTrait<dt>::enumv: \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
backward<DTypeTrait<dt>::ctype>(diff, map, grad)); \
return;
cb(dtype::Float32)
cb(dtype::Int32)
#undef cb
default:
megdnn_throw(ssprintf(
"unsupported dtype %s in indexing remap backward naive\n",
diff.layout.dtype.name()));
}
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -16,39 +16,42 @@
namespace megdnn {
namespace test {
TEST_F(CUDA, TENSOR_REMAP_FORWARD)
{
TEST_F(CUDA, TENSOR_REMAP_FORWARD) {
Checker<IndexingRemapForward> checker(handle_cuda());
TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
checker.set_dtype(1, dtype::Int32());
TensorShape src{11, 13, 17},
map{3, 5, 7, 3},
dst{3, 5, 7};
using namespace tensor_remap;
{
MapRNG rng(src);
checker.set_rng(1, &rng).execs({src, map, {}});
}
{
NonoverlappingMapRNG rng(src);
checker.set_rng(1, &rng).execs({src, map, {}});
for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
checker.set_dtype(0, dt);
checker.set_dtype(2, dt);
using namespace tensor_remap;
{
MapRNG rng(src);
checker.set_rng(1, &rng).execs({src, map, {}});
}
{
NonoverlappingMapRNG rng(src);
checker.set_rng(1, &rng).execs({src, map, {}});
}
}
}
TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
{
TEST_F(CUDA, TENSOR_REMAP_BACKWARD) {
Checker<IndexingRemapBackward> checker(handle_cuda());
checker.set_dtype(1, dtype::Int32());
TensorShape src{11, 13, 17},
map{3, 5, 7, 3},
dst{3, 5, 7};
using namespace tensor_remap;
{
MapRNG rng(src);
checker.set_rng(1, &rng).execs({dst, map, src});
}
{
NonoverlappingMapRNG rng(src);
checker.set_rng(1, &rng).execs({dst, map, src});
TensorShape src{11, 13, 17}, map{3, 5, 7, 3}, dst{3, 5, 7};
checker.set_dtype(1, dtype::Int32());
for (auto dt : std::vector<DType>{dtype::Float32(), dtype::Int32()}) {
checker.set_dtype(0, dt);
checker.set_dtype(2, dt);
using namespace tensor_remap;
{
MapRNG rng(src);
checker.set_rng(1, &rng).execs({dst, map, src});
}
{
NonoverlappingMapRNG rng(src);
checker.set_rng(1, &rng).execs({dst, map, src});
}
}
}
......@@ -56,5 +59,3 @@ TEST_F(CUDA, TENSOR_REMAP_BACKWARD)
} // namespace megdnn
// vim: syntax=cpp.doxygen