提交 3977b7aa 编写于 作者: M Megvii Engine Team

feat(mgb/shuffle): add shuffle opr

GitOrigin-RevId: 80490a6f848d524111bee097f11b591b5a3956c8
上级 17371e79
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megdnn/internal/opr_header_prologue.h"
......@@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase {
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
};
class ShuffleRNGForward : public OperatorBase {
DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
DEF_OPR_PARAM(ShuffleRNG);
public:
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0;
void deduce_layout(const TensorLayout& src, TensorLayout& dst,
TensorLayout& indices);
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices) = 0;
protected:
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
};
using ShuffleRNG = ShuffleRNGForward;
class ShuffleRNGBackward : public OperatorBase {
DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
DEF_OPR_PARAM(ShuffleRNG);
public:
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& indices,
const TensorLayout& grad) = 0;
protected:
void check_exec(const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
};
/*!
* \brief sleep for specific time on the computing device; useful for testing
* async problems
......
......@@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'Float32 are supported.'),
'DTypeEnum::Int32'))
(pdef('ShuffleRNG').
add_fields('uint64', 'seed', 0))
(pdef('Flip').
add_fields('bool', 'vertical', 'false', 'horizontal', 'false'))
......
......@@ -165,6 +165,8 @@ private:
cb(BetaRNG) \
cb(PoissonRNG) \
cb(PermutationRNG) \
cb(ShuffleRNGForward) \
cb(ShuffleRNGBackward) \
cb(SeparableConvForward) \
cb(SeparableFilterForward) \
cb(BNForward) \
......
......@@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true);
DEF(BetaRNG, 3, true, true);
DEF(PoissonRNG, 2, true, true);
DEF(PermutationRNG, 1, true, true);
DEF(ShuffleRNGForward, 3, true, true);
DEF(ShuffleRNGBackward, 3, true, false);
DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true);
......
......@@ -15,6 +15,47 @@
namespace megdnn {
void ShuffleRNGForward::deduce_layout(const TensorLayout& src,
TensorLayout& dst,
TensorLayout& indices) {
dst = src;
indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32());
}
void ShuffleRNGForward::check_exec(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices,
size_t workspace_in_bytes) {
TensorLayout dst_expected, indices_expected;
megdnn_assert_contiguous(src);
deduce_layout(src, dst_expected, indices_expected);
megdnn_assert_eq_layout(dst_expected, dst);
megdnn_assert_eq_layout(indices_expected, indices);
megdnn_assert_contiguous(indices);
megdnn_assert(src.dtype == dst.dtype);
megdnn_assert(indices.dtype == dtype::Int32());
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, dst, indices);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void ShuffleRNGBackward::check_exec(const TensorLayout& diff,
const TensorLayout& indices,
const TensorLayout& grad,
size_t workspace_in_bytes) {
megdnn_assert(
diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype &&
indices.dtype == dtype::Int32{} && diff.is_contiguous() &&
indices.is_contiguous() && grad.is_contiguous(),
"invalid layouts: diff=%s indices=%s grad=%s",
diff.to_string().c_str(), indices.to_string().c_str(),
grad.to_string().c_str());
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, indices, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void PermutationRNG::check_exec(
const TensorLayout &dst, size_t workspace_in_bytes) {
megdnn_assert((dst.dtype == dtype::Float32() ||
......
......@@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs,
}
}
template <typename T>
__global__ void shuffle_fwd_kernel(uint32_t step, uint32_t src_size, const T* sptr,
T* dptr, const int* iptr) {
uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < src_size) {
uint32_t r = idx / step;
dptr[idx]=sptr[iptr[r] * step + idx % step];
}
}
template <typename T>
void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr,
size_t len, size_t step, cudaStream_t stream) {
uint32_t src_size = len * step;
shuffle_fwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
step, src_size, sptr, dptr, iptr);
after_kernel_launch();
}
template <typename T>
__global__ void shuffle_bwd_kernel(uint32_t step, uint32_t src_size, T* sptr,
T* dptr, const int* iptr) {
uint32_t idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < src_size) {
uint32_t r = idx / step;
sptr[iptr[r] * step + idx % step]=dptr[idx];
}
}
template <typename T>
void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,
size_t len, size_t step, cudaStream_t stream) {
uint32_t src_size = len * step;
shuffle_bwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
step, src_size, sptr, dptr, iptr);
after_kernel_launch();
}
uint32_t get_permutation_bits(size_t N) {
double uniq_rand_num_prob = 0.9;
double thresh = std::log(uniq_rand_num_prob) * 12;
......@@ -156,6 +192,14 @@ INST_PERMUTATION(dt_int16)
INST_PERMUTATION(dt_float32)
#undef INST_PERMUTATION
#define INST_SHUFFLE(T) \
template void shuffle_forward<T>(T* sptr, T* dptr, dt_int32* iptr,\
size_t len, size_t step, cudaStream_t stream);\
template void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,\
size_t len, size_t step, cudaStream_t stream);
ARGSORT_FOREACH_CTYPE(INST_SHUFFLE)
#undef INST_SHUFFLE
} // namespace random
#define INST(_dtype) \
......
......@@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed
size_t get_permutation_workspace_in_bytes(size_t N);
template<typename T>
void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr,
size_t len, size_t step, cudaStream_t stream);
template<typename T>
void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,
size_t len, size_t step, cudaStream_t stream);
#define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
} // namespace random
} // namespace cuda
} // namespace megdnn
......@@ -9,11 +9,11 @@
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "./opr_impl.h"
#include "./kernel.cuh"
#include "src/common/utils.h"
#include "src/cuda/handle.h"
#include "src/cuda/utils.h"
#include "./opr_impl.h"
#include "./kernel.cuh"
using namespace megdnn;
using namespace cuda;
......@@ -261,5 +261,76 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){
return random::get_permutation_workspace_in_bytes(size);
}
ShuffleRNGForwardImpl::ShuffleRNGForwardImpl(Handle* handle)
: ShuffleRNGForward(handle),
m_seed(0),
m_offset(0),
m_stream(cuda_stream(handle)) {}
void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, indices.layout, workspace.size);
ensure_seed(m_param.seed);
auto wk = workspace.ptr<void>();
const auto len = indices.layout[0];
random::permutation_forward<dt_int32>(indices.ptr<dt_int32>(), wk, len,
m_seed, m_offset, m_stream);
size_t step = 0;
for (size_t i = 1; i < src.layout.ndim; ++i) {
step += src.layout[i];
}
if (step <= 0)
step = 1;
switch (src.layout.dtype.enumv()) {
#define cb(DType) \
case DTypeTrait<DType>::enumv: \
random::shuffle_forward<DTypeTrait<DType>::ctype>( \
src.ptr<DTypeTrait<DType>::ctype>(), \
dst.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
len, step, m_stream); \
break;
ARGSORT_FOREACH_CTYPE(cb)
#undef cb
default : megdnn_throw("bad dtype");
}
m_offset += 8;
}
size_t ShuffleRNGForwardImpl::get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&,
const TensorLayout& indices) {
size_t size = indices.total_nr_elems();
return random::get_permutation_workspace_in_bytes(size);
}
ShuffleRNGBackwardImpl::ShuffleRNGBackwardImpl(Handle* handle)
: ShuffleRNGBackward(handle), m_stream(cuda_stream(handle)) {}
void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in indices,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
const auto len = indices.layout[0];
auto step = 0;
for (size_t i = 1; i < diff.layout.ndim; ++i) {
step += diff.layout[i];
}
if (step <= 0)
step = 1;
switch (diff.layout.dtype.enumv()) {
#define cb(DType) \
case DTypeTrait<DType>::enumv: \
random::shuffle_backward<DTypeTrait<DType>::ctype>( \
diff.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
grad.ptr<DTypeTrait<DType>::ctype>(), len, step, m_stream); \
break;
ARGSORT_FOREACH_CTYPE(cb)
#undef cb
default:
megdnn_throw("bad dtype");
}
}
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
......@@ -152,6 +153,45 @@ public:
}
};
class ShuffleRNGForwardImpl : public ShuffleRNGForward {
uint64_t m_seed, m_offset;
cudaStream_t m_stream;
public:
using ShuffleRNGForward::ShuffleRNGForward;
ShuffleRNGForwardImpl(Handle* handle);
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices) override;
void seed(uint64_t seed) { m_seed = seed; }
void ensure_seed(uint64_t seed) {
if (m_seed != seed) {
this->seed(seed);
}
}
};
class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
cudaStream_t m_stream;
public:
using ShuffleRNGBackward::ShuffleRNGBackward;
ShuffleRNGBackwardImpl(Handle* handle);
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,12 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "./opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <cmath>
......@@ -229,7 +230,29 @@ namespace {
}
}
} // anonymous namespace
template <typename T>
void shuffle_fwd(const T* __restrict sptr, T* __restrict dptr,
const dt_int32* iptr, const size_t len,
const size_t step) MEGDNN_NOEXCEPT {
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
dptr[i * step + j] = sptr[iptr[i] * step + j];
}
}
}
template <typename T>
void shuffle_bwd(T* __restrict sptr, const T* __restrict dptr,
const dt_int32* iptr, const size_t len,
const size_t step) MEGDNN_NOEXCEPT {
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
sptr[iptr[i] * step + j] = dptr[i * step + j];
}
}
}
} // anonymous namespace
uint64_t Splitmix64::operator() () {
uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15));
......@@ -394,5 +417,54 @@ void PermutationRNGImpl::exec(
}
}
// vim: syntax=cpp.doxygen
void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, indices.layout, workspace.size);
const auto len = indices.layout[0];
auto iptr = indices.ptr<dt_int32>();
auto prng = &m_rng.ensure_seed(m_param.seed);
fill_permutation<dt_int32>(prng, iptr, len);
auto step = 0;
for (size_t i = 1; i < src.layout.ndim; ++i) {
step += src.layout[i];
}
if (step <= 0)
step = 1;
#define cb(DType) \
if (src.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
shuffle_fwd<T>(src.ptr<T>(), dst.ptr<T>(), iptr, len, step)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
_megdnn_tensor_in indices,
_megdnn_tensor_out grad,
_megdnn_workspace workspace) {
check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
const auto len = indices.layout[0];
auto iptr = indices.ptr<dt_int32>();
auto step = 0;
for (size_t i = 1; i < diff.layout.ndim; ++i) {
step += diff.layout[i];
}
if (step <= 0)
step = 1;
#define cb(DType) \
if (diff.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR(shuffle_bwd<T>( \
grad.ptr<T>(), diff.ptr<T>(), iptr, len, step)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
// vim: syntax=cpp.doxygen
......@@ -128,6 +128,35 @@ public:
size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
};
class ShuffleRNGForwardImpl : public ShuffleRNGForward {
Xoroshiro128plus m_rng;
public:
using ShuffleRNGForward::ShuffleRNGForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
Xoroshiro128plus m_rng;
public:
using ShuffleRNGBackward::ShuffleRNGBackward;
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -143,6 +143,60 @@ void run_permutation(Handle* handle) {
}
}
template <typename T>
void run_shuffle(Handle* handle, bool bwd_flag) {
using ctype = typename DTypeTrait<T>::ctype;
auto run = [&](TensorShape shape) {
auto opr = handle->create_operator<ShuffleRNGForward>();
TensorLayout srclay{shape, T()};
TensorLayout dstlay{shape, T()};
TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
Tensor<dt_byte> workspace(
handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay,
indexlay)},
dtype::Byte()});
SyncedTensor<ctype> src(handle, srclay);
SyncedTensor<ctype> dst(handle, dstlay);
SyncedTensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
auto sptr = src.ptr_mutable_host();
size_t size = src.layout().total_nr_elems();
for (size_t j = 0; j < size; ++j) {
sptr[j] = j;
}
opr->exec(src.tensornd_dev(), dst.tensornd_dev(), index.tensornd_dev(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
auto dptr = dst.ptr_mutable_host();
auto iptr = index.ptr_mutable_host();
size_t len = index.layout().total_nr_elems();
size_t step = size / len;
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
}
}
if (bwd_flag) {
for (size_t j = 0; j < size; ++j) {
sptr[j] = 0;
}
auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
oprbwd->exec(
dst.tensornd_dev(), index.tensornd_dev(),
src.tensornd_dev(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
auto sptr_bwd = src.ptr_mutable_host();
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr_bwd[iptr[i] * step + j]);
}
}
}
};
run({10});
run({6, 3});
}
} // anonymous namespace
TEST_F(CUDA, UNIFORM_RNG_F32) {
......@@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) {
run_permutation<dtype::Int16>(handle_cuda());
}
TEST_F(CUDA, SHUFFLE_RNG_F32) {
run_shuffle<dtype::Float32>(handle_cuda(), false);
}
TEST_F(CUDA, SHUFFLE_RNG_INT32) {
run_shuffle<dtype::Int32>(handle_cuda(), false);
}
TEST_F(CUDA, SHUFFLE_RNG_F16) {
run_shuffle<dtype::Float16>(handle_cuda(), false);
}
TEST_F(CUDA, SHUFFLE_RNG_BWD_F32) {
run_shuffle<dtype::Float32>(handle_cuda(), true);
}
TEST_F(CUDA, SHUFFLE_RNG_BWD_INT32) {
run_shuffle<dtype::Int32>(handle_cuda(), true);
}
TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) {
run_shuffle<dtype::Float16>(handle_cuda(), true);
}
} // namespace test
} // namespace megdnn
......
......@@ -6,12 +6,13 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megdnn.h"
#include "test/naive/fixture.h"
#include "test/naive/rng.h"
#include "megdnn.h"
#include "test/common/tensor.h"
#include "test/naive/fixture.h"
namespace megdnn {
......@@ -181,7 +182,59 @@ namespace {
ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8);
}
}
}
template <typename T>
void run_shuffle(Handle* handle, bool bwd_flag) {
using ctype = typename DTypeTrait<T>::ctype;
auto run = [&](TensorShape shape) {
auto opr = handle->create_operator<ShuffleRNGForward>();
TensorLayout srclay{shape, T()};
TensorLayout dstlay{shape, T()};
TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
Tensor<dt_byte> workspace(
handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay,
indexlay)},
dtype::Byte()});
Tensor<ctype> src(handle, srclay);
Tensor<ctype> dst(handle, dstlay);
Tensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
auto sptr = src.ptr();
size_t size = src.layout().total_nr_elems();
for (size_t j = 0; j < size; ++j) {
sptr[j] = j;
}
opr->exec(src.tensornd(), dst.tensornd(), index.tensornd(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
auto dptr = dst.ptr();
auto iptr = index.ptr();
size_t len = index.layout().total_nr_elems();
size_t step = size / len;
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
}
}
if (bwd_flag) {
for (size_t j = 0; j < size; ++j) {
sptr[j] = 0;
}
auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
oprbwd->exec(
dst.tensornd(), index.tensornd(), src.tensornd(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
}
}
}
};
run({10});
run({6, 3});
}
} // namespace
TEST_F(NAIVE, UNIFORM_RNG_F32) {
run_uniform<dtype::Float32>(handle());
......@@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) {
run_permutation<dtype::Int16>(handle());
}
} // namespace test
} // namespace megdnn
TEST_F(NAIVE, SHUFFLE_RNG_FWD_F32) {
run_shuffle<dtype::Float32>(handle(), false);
}
// vim: syntax=cpp.doxygen
TEST_F(NAIVE, SHUFFLE_RNG_FWD_INT32) {
run_shuffle<dtype::Int32>(handle(), false);
}
TEST_F(NAIVE, SHUFFLE_RNG_FWD_F16) {
run_shuffle<dtype::Float16>(handle(), false);
}
TEST_F(NAIVE, SHUFFLE_RNG_BWD_F32) {
run_shuffle<dtype::Float32>(handle(), true);
}
TEST_F(NAIVE, SHUFFLE_RNG_BWD_INT32) {
run_shuffle<dtype::Int32>(handle(), true);
}
TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) {
run_shuffle<dtype::Float16>(handle(), true);
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform
from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, shuffle, uniform
__all__ = [
"RNG",
......@@ -17,6 +17,7 @@ __all__ = [
"poisson",
"seed",
"uniform",
"shuffle",
]
# pylint: disable=undefined-variable
del rng # type: ignore[name-defined]
......@@ -27,6 +27,7 @@ from ..core.ops.builtin import (
GaussianRNG,
PermutationRNG,
PoissonRNG,
ShuffleRNG,
UniformRNG,
)
from ..core.tensor import utils
......@@ -41,6 +42,7 @@ __all__ = [
"beta",
"poisson",
"permutation",
"shuffle",
]
_rng = None
......@@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten
return output
def _shuffle(inp: Tensor, seed: int, handle: int) -> Tensor:
assert inp.size > 0, "size needs to be greater than 0"
op = ShuffleRNG(seed=seed, handle=handle)
output, _ = apply(op, inp)
inp._reset(output)
class RNG:
r""":class:`RNG` exposes a number of methods for generating random numbers.
......@@ -581,6 +590,45 @@ class RNG:
n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
)
def shuffle(self, inp: Tensor):
r"""Modify a sequence in-place by shuffling its contents.
This function only shuffles the Tensor along the first axis of a multi-dimensional Tensor.
The order of sub-Tensors is changed but their contents remains the same.
Args:
inp: input tensor.
Examples:
.. testcode::
import numpy as np
import megengine as mge
import megengine.random as rand
x = mge.tensor(np.arange(10))
rand.shuffle(x)
print(x.numpy())
y = mge.tensor(np.arange(18)).reshape(6,3)
rand.shuffle(y)
print(y.numpy())
Outputs:
.. testoutput::
:options: +SKIP
[7 9 3 0 8 2 4 5 6 1]
[[12. 13. 14.]
[ 3. 4. 5.]
[15. 16. 17.]
[ 0. 1. 2.]
[ 9. 10. 11.]
[ 6. 7. 8.]]
"""
_seed = self._seed() if callable(self._seed) else self._seed
_shuffle(inp=inp, seed=_seed, handle=self._handle)
def __del__(self):
if self._handle != 0:
_delete_rng_handle(self._handle)
......@@ -599,6 +647,7 @@ gamma = _default_handle.gamma
beta = _default_handle.beta
poisson = _default_handle.poisson
permutation = _default_handle.permutation
shuffle = _default_handle.shuffle
def _random_seed_generator():
......
......@@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import (
get_global_rng_seed,
new_rng_handle,
)
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import (
BetaRNG,
GammaRNG,
......@@ -397,6 +398,45 @@ def test_PermutationRNG():
assert sum_result(out, np.sort) == 1000
@pytest.mark.skipif(
get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_ShuffleRNG():
g = []
def cb(grad):
g.append(grad)
n, m = 6, 3
arr = np.arange(n * m)
out0 = Tensor(arr, dtype="float32")
grad = Grad().wrt(out0, callback=cb)
random.shuffle(out0)
grad(out0, F.ones_like(out0))
m1 = RNG(seed=111, device="xpu0")
m2 = RNG(seed=111, device="xpu1")
m3 = RNG(seed=222, device="xpu0")
out1 = Tensor(arr, dtype="float32", device="xpu0")
out2 = Tensor(arr, dtype="float32", device="xpu1")
out3 = Tensor(arr, dtype="float32", device="xpu0")
m1.shuffle(out1)
m2.shuffle(out2)
m3.shuffle(out3)
np.testing.assert_equal(out1.numpy(), out2.numpy())
assert out1.device == "xpu0" and out2.device == "xpu1"
assert not (out1.numpy() == out3.numpy()).all()
out = Tensor(arr, dtype="float32").reshape(n, m)
m1.shuffle(out)
out_shp = out.shape
if isinstance(out_shp, tuple):
assert out_shp == (n, m)
else:
assert all(out.shape.numpy() == np.array([n, m]))
def test_seed():
set_global_seed(10)
out1 = uniform(size=[10, 10])
......
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/imperative/ops/rng.h"
......@@ -14,8 +15,8 @@
#include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h"
#include "../op_trait.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative::rng {
......@@ -259,13 +260,27 @@ struct OpMeth<BetaRNG> {
}
};
template <>
struct OpMeth<ShuffleRNG> {
using DnnOp = megdnn::ShuffleRNG;
using Param = DnnOp::Param;
using OpNode = mgb::opr::ShuffleRNG;
static Param make_param(const ShuffleRNG& rng) {
auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
mgb_assert(handle_seed == rng.seed,
"inconsistent rng seed: rng op: %lu handle: %lu",
handle_seed, rng.seed);
return {handle_seed};
}
};
template <bool>
struct _InferLayout;
template <int nr_in>
struct _RNGOprMaker;
template <int nr_in>
template <int nr_in, int nr_out>
struct _RNGOprInvoker;
template<>
......@@ -316,50 +331,63 @@ struct _InferLayout<false>
return inp.layout;
}
};
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \
template<> \
struct _RNGOprInvoker<DNN_NR_INPUTS> { \
template<typename Opr> \
static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \
size_t wk_size = 0; \
wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \
auto workspace = Blob::make(dest->comp_node(), wk_size); \
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
dest->dev_tensor().as_megdnn(), dnn_wk); \
} \
};
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS) \
template <> \
struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> { \
template <typename Opr> \
static void exec(Opr* dnn_op, const SmallVector<TensorPtr>& inputs, \
const SmallVector<TensorPtr>& outputs) { \
size_t wk_size = 0; \
wk_size = dnn_op->get_workspace_in_bytes( \
_FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout())); \
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
_FOR_EACH_OUT(->dev_tensor().as_megdnn()), \
dnn_wk); \
} \
};
#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template<> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \
template<typename Op> \
static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \
auto param = OpMeth<Op>::make_param(rng); \
OperatorNodeConfig config; \
if (rng.handle) { \
config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \
} else { \
config = {rng.make_name()}; \
} \
return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
} \
};
#define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template <> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \
template <typename Op> \
static auto make(const VarNodeArray& inputs, const Op& rng) { \
auto param = OpMeth<Op>::make_param(rng); \
OperatorNodeConfig config; \
if (rng.handle) { \
config = {rng.make_name(), \
RNGDnnOpManager::get_comp_node(rng.handle)}; \
} else { \
config = {rng.make_name()}; \
} \
return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
} \
};
#define _FOR_EACH_IN(subfix)
_INST_RNG_INVOLKER(0)
#define _FOR_EACH_IN(subfix)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) inputs[0] subfix,
_INST_RNG_INVOLKER(1)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
_INST_RNG_INVOLKER(2)
#define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN
#undef _INST_RNG_INVOLKER
......@@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
handle_seed, dnn_op->param().seed);
}
dnn_op->param() = OpMeth<Op>::make_param(rng);
_RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest);
_RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS,
OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs,
outputs);
}
template <typename Op>
......@@ -420,24 +450,45 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
return {dest};
}
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc(
const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto &&dest = infer_output_attrs<Op>(def, inputs_tensors);
SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}};
return {outputs, {}};
template <>
SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
const OpDef& op, const SmallVector<TensorPtr>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
auto&& rng = op.cast_final_safe<ShuffleRNG>();
auto handle = rng.handle;
if (handle) {
dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
} else {
dests[0].comp_node = inputs[0]->comp_node();
dests[1].comp_node = inputs[0]->comp_node();
}
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
dests[1].layout =
TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
return dests;
}
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_output_mem_desc(const OpDef& def,
const SmallVector<TensorPtr>& inputs_tensors,
const SmallVector<MemoryDesc>& inputs_mems) {
auto&& dests = infer_output_attrs<Op>(def, inputs_tensors);
SmallVector<MemoryDesc> outputs;
for (size_t i = 0; i < dests.size(); ++i) {
outputs.push_back({dests[i].layout, 0, dests[i].comp_node,
StorageIdentifier::make(i + 1)});
}
return {outputs, {}};
}
template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<TensorPtr> outputs;
SmallVector<LogicalTensorDesc> desc;
desc = infer_output_attrs<Op>(def, inputs);
SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node));
}
......@@ -454,10 +505,8 @@ void execute(
exec<Op>(def, inputs, outputs, {});
}
template<typename Op>
SymbolVar apply_on_var_node(
const OpDef& def,
const VarNodeArray& inputs) {
template <typename Op, typename Output>
Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
size_t nr_inp = inputs.size();
constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
auto&& rng = def.cast_final_safe<Op>();
......@@ -487,7 +536,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{dest}, true};
}
} // anonymous namespace
template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool>
infer_output_attrs_fallible<ShuffleRNG>(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
dests[0].comp_node = inputs[0].comp_node;
dests[0].layout = TensorLayout(inputs[0].layout);
dests[0].layout.dtype = inputs[0].layout.dtype;
dests[1].comp_node = inputs[0].comp_node;
dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}),
dtype::Int32());
return {dests, true};
}
} // anonymous namespace
Handle new_handle(CompNode comp_node, uint64_t seed) {
return RNGDnnOpManager::inst().new_handle(comp_node, seed);
......@@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){
return RNGDnnOpManager::get_comp_node(handle);
}
#define REG_RNG_OP(NAME)\
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.apply_on_var_node(apply_on_var_node<NAME>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.infer_output_mem_desc(infer_output_mem_desc<NAME>) \
.execute(execute<NAME>) \
.fallback(); \
} \
REG_RNG_OP(UniformRNG)
REG_RNG_OP(GaussianRNG)
REG_RNG_OP(GammaRNG)
REG_RNG_OP(PermutationRNG)
REG_RNG_OP(PoissonRNG)
REG_RNG_OP(BetaRNG)
#define REG_RNG_OP(NAME, Output) \
namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.apply_on_var_node(apply_on_var_node<NAME, Output>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.infer_output_mem_desc(infer_output_mem_desc<NAME>) \
.execute(execute<NAME>) \
.fallback(); \
}
REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
#undef REG_RNG_OP
} // namespace mgb::imperative::rng
......
......@@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
}
def ShuffleRNG: MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
);
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash($_self.handle)
);
}];
let cmpFunction = [{return $0.handle == $1.handle;}];
}
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
let extraArguments = (ins
MgbCompNodeAttr:$comp_node
......
......@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>;
template class RNGOprBase<::megdnn::PermutationRNG>;
template class RNGOprBase<::megdnn::BetaRNG>;
template class RNGOprBase<::megdnn::PoissonRNG>;
template class RNGOprBase<::megdnn::ShuffleRNGForward>;
template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
#if MGB_ENABLE_GRAD
IMPL(GaussianRNG);
IMPL(UniformRNG);
......@@ -200,9 +202,87 @@ IMPL(PoissonRNG);
IMPL(PermutationRNG);
IMPL(BetaRNG);
#endif
}
} // namespace intl
} // namespace opr
} // namespace mgb
/* ================= ShuffleRNGForward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward);
ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param,
const OperatorNodeConfig& config)
: Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) {
add_input({data});
add_output(None)->dtype(data->dtype());
add_output(None)->dtype(dtype::Int32{});
cg::add_workspace_output(this);
add_equivalence_component<ScalarHash<void*>>(this);
}
SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param,
const OperatorNodeConfig& config) {
auto node = in_tensor.node()->owner_graph()->insert_opr(
std::make_unique<ShuffleRNGForward>(in_tensor.node(), param,
config));
mgb_assert(node->output().size() == 3);
return {node->output(0), node->output(1)};
}
void ShuffleRNGForward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0),
ShapeInferDesc::make_identity(input(0)));
auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) {
TensorLayout o0, o1;
m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0,
o1);
dest = o1;
return true;
};
mgr.register_shape_infer(
output(1),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1});
auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
ensure_megdnn_opr();
dest.ndim = 1;
dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
{inp.val[0].shape(), input(0)->dtype()},
{output(0)->shape(), output(0)->dtype()},
{output(1)->shape(), output(1)->dtype()});
return true;
};
mgr.register_shape_infer(
output(2),
{SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk});
}
void ShuffleRNGForward::add_input_layout_constraint() {
input(0)->add_layout_constraint_contiguous();
};
void ShuffleRNGForward::scn_do_execute() {
m_dnn_opr->exec(input(0)->dev_tensor().as_megdnn(),
output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
get_megdnn_workspace_from_var(output(2)));
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(ShuffleRNGForward) {
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
if (!out_grad[0])
return nullptr;
return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0)).node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward);
MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -6,7 +6,8 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#include "megbrain/opr/rand.h"
......@@ -14,6 +15,23 @@
namespace mgb {
namespace serialization {
template <>
struct OprMaker<opr::ShuffleRNG, 1> {
using Opr = opr::ShuffleRNG;
using Param = Opr::Param;
static cg::OperatorNodeBase* make(const Param& param,
const cg::VarNodeArray& inputs,
ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
auto out = Opr::make(inputs[0], param, config);
return out[0].node()->owner_opr();
}
};
} // namespace serialization
namespace opr {
using UniformRNGV1 = opr::UniformRNG;
......@@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2);
MGB_SEREG_OPR(ShuffleRNG, 1);
MGB_SEREG_OPR(ShuffleRNGBackward, 3);
} // namespace opr
} // namespace mgb
} // namespace opr
} // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -6,14 +6,15 @@
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/
#pragma once
#include "megbrain/graph.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megdnn/oprs.h"
namespace mgb {
......@@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // {
};
/* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
public: \
RNG(VarNode *shape, const Param &param, const OperatorNodeConfig &config); \
static SymbolVar make(SymbolVar shape, const Param &param = {}, \
const OperatorNodeConfig &config = {}); \
static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \
const OperatorNodeConfig &config, \
const Param &param = {}) { \
return make(var_from_tensor_shape(graph, config, "rng", shape), \
param, config); \
} \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
};
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
\
public: \
RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \
static SymbolVar make(SymbolVar shape, const Param& param = {}, \
const OperatorNodeConfig& config = {}); \
static SymbolVar make(ComputingGraph& graph, const TensorShape& shape, \
const OperatorNodeConfig& config, \
const Param& param = {}) { \
return make(var_from_tensor_shape(graph, config, "rng", shape), param, \
config); \
} \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
} \
;
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG)
......@@ -71,7 +74,7 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)
public: \
RNG(_INPUTS(VarNode*), const Param &param, \
const OperatorNodeConfig &config); \
static SymbolVar make(_INPUTS(SymbolVar),const Param &param = {}, \
static _OUTPUTS make(_INPUTS(SymbolVar),const Param &param = {}, \
const OperatorNodeConfig &config = {}); \
void init_output_static_infer_desc() override; \
void scn_do_execute() override; \
......@@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)
/* ================= 1 input ================= */
#define _INPUTS(preifx) preifx i0
#define _OUTPUTS SymbolVar
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG)
#undef _OUTPUTS
#define _OUTPUTS SymbolVarArray
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward)
#undef _OUTPUTS
#undef _INPUTS
/* ================= 2 input ================= */
#define _INPUTS(preifx) preifx i0, preifx i1
#define _OUTPUTS SymbolVar
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _OUTPUTS
#undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
} // intl
} // intl
using UniformRNG = intl::UniformRNG;
using GaussianRNG = intl::GaussianRNG;
......@@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG;
using PermutationRNG = intl::PermutationRNG;
using PoissonRNG = intl::PoissonRNG;
using BetaRNG = intl::BetaRNG;
} // namespace opr
} // namespace mgb
using ShuffleRNG = intl::ShuffleRNGForward;
MGB_DEFINE_OPR_CLASS(ShuffleRNGBackward,
intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{
public:
ShuffleRNGBackward(VarNode* out_diff, VarNode* indices, VarNode* result_shape,
const Param& param, const OperatorNodeConfig& config);
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
static SymbolVar make(SymbolVar out_diff, SymbolVar indices,
SymbolVar result_shape, const Param& param = {},
const OperatorNodeConfig& config = {});
};
} // namespace opr
} // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -333,6 +333,38 @@ TEST(TestOprRand, EmptyShape) {
}
TEST(TestOprRand, ShuffleForward) {
auto run = [&](TensorShape shape) {
std::shared_ptr<HostTensorND> src_host(new HostTensorND{
CompNode::load("xpux"), shape, dtype::Float32()});
auto sptr = src_host->ptr<dt_float32>();
auto size = shape.total_nr_elems();
for (size_t i = 0; i < size; ++i) {
sptr[i] = i;
}
auto graph = ComputingGraph::make();
auto src_sym = opr::Host2DeviceCopy::make(*graph, src_host);
auto rec = opr::ShuffleRNG::make(src_sym, {10});
HostTensorND host_y, host_index;
auto func = graph->compile({make_callback_copy(rec[0], host_y),
make_callback_copy(rec[1], host_index)});
func->execute();
auto dptr = host_y.ptr<dt_float32>();
auto iptr = host_index.ptr<dt_int32>();
size_t len = shape[0];
size_t step = size / len;
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
assert(dptr[i * step + j] == sptr[iptr[i] * step + j]);
}
}
};
run({10});
run({6, 3});
run({1, 1});
}
TEST(TestOprRand, UniformReprod) {
static constexpr size_t SIZE = 123;
auto graph = ComputingGraph::make();
......
......@@ -114,6 +114,7 @@ union OperatorParam {
param.BetaRNG = 80,
param.SlidingWindowTranspose = 81,
param.Padding = 82,
param.ShuffleRNG = 83,
}
table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册