提交 3977b7aa 编写于 作者: M Megvii Engine Team

feat(mgb/shuffle): add shuffle opr

GitOrigin-RevId: 80490a6f848d524111bee097f11b591b5a3956c8
上级 17371e79
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megdnn/internal/opr_header_prologue.h" #include "megdnn/internal/opr_header_prologue.h"
...@@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase { ...@@ -94,6 +95,42 @@ class PermutationRNG: public RNGBase {
void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); void check_exec(const TensorLayout &dst, size_t workspace_in_bytes);
}; };
/*!
 * \brief forward operator that shuffles \p src along its first axis
 *
 * As laid out by deduce_layout(), \p dst has the same layout as \p src and
 * \p indices is a 1-D Int32 tensor with src.shape[0] entries. The backends
 * fill \p indices with a permutation and copy slice indices[i] of \p src
 * into slice i of \p dst.
 */
class ShuffleRNGForward : public OperatorBase {
// presumably the trailing "1, 2" declare one input and two outputs -- confirm
// against the DEF_OPR_IMPL definition
DEF_OPR_IMPL(ShuffleRNGForward, OperatorBase, 1, 2);
DEF_OPR_PARAM(ShuffleRNG);
public:
/*!
 * \param[in] src tensor to be shuffled
 * \param[out] dst shuffled result
 * \param[out] indices permutation that was applied
 * \param workspace scratch memory of at least get_workspace_in_bytes() bytes
 */
virtual void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) = 0;
//! derive the layouts of \p dst and \p indices from \p src
void deduce_layout(const TensorLayout& src, TensorLayout& dst,
TensorLayout& indices);
//! workspace size required by exec() for the given layouts
virtual size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices) = 0;
protected:
//! validate layouts and workspace size; intended to be called from exec()
void check_exec(const TensorLayout& src, const TensorLayout& dst,
const TensorLayout& indices, size_t workspace_in_bytes);
};
using ShuffleRNG = ShuffleRNGForward;
/*!
 * \brief backward operator of ShuffleRNGForward
 *
 * The backends scatter slice i of \p diff into slice indices[i] of \p grad,
 * i.e. they undo the permutation recorded by the forward pass.
 */
class ShuffleRNGBackward : public OperatorBase {
// presumably the trailing "2, 1" declare two inputs and one output -- confirm
// against the DEF_OPR_IMPL definition
DEF_OPR_IMPL(ShuffleRNGBackward, OperatorBase, 2, 1);
DEF_OPR_PARAM(ShuffleRNG);
public:
/*!
 * \param[in] diff gradient w.r.t. the shuffled output
 * \param[in] indices permutation produced by the forward operator
 * \param[out] grad gradient w.r.t. the forward input
 * \param workspace scratch memory of at least get_workspace_in_bytes() bytes
 */
virtual void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) = 0;
//! workspace size required by exec() for the given layouts
virtual size_t get_workspace_in_bytes(const TensorLayout& diff,
const TensorLayout& indices,
const TensorLayout& grad) = 0;
protected:
//! validate layouts and workspace size; intended to be called from exec()
void check_exec(const TensorLayout& diff, const TensorLayout& indices,
const TensorLayout& grad, size_t workspace_in_bytes);
};
/*! /*!
* \brief sleep for specific time on the computing device; useful for testing * \brief sleep for specific time on the computing device; useful for testing
* async problems * async problems
......
...@@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) ...@@ -781,6 +781,9 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0)
'Float32 are supported.'), 'Float32 are supported.'),
'DTypeEnum::Int32')) 'DTypeEnum::Int32'))
# ShuffleRNG: parameter set of the shuffle operators; a single uint64 `seed`
# field whose default value is 0.
(pdef('ShuffleRNG').
add_fields('uint64', 'seed', 0))
(pdef('Flip'). (pdef('Flip').
add_fields('bool', 'vertical', 'false', 'horizontal', 'false')) add_fields('bool', 'vertical', 'false', 'horizontal', 'false'))
......
...@@ -165,6 +165,8 @@ private: ...@@ -165,6 +165,8 @@ private:
cb(BetaRNG) \ cb(BetaRNG) \
cb(PoissonRNG) \ cb(PoissonRNG) \
cb(PermutationRNG) \ cb(PermutationRNG) \
cb(ShuffleRNGForward) \
cb(ShuffleRNGBackward) \
cb(SeparableConvForward) \ cb(SeparableConvForward) \
cb(SeparableFilterForward) \ cb(SeparableFilterForward) \
cb(BNForward) \ cb(BNForward) \
......
...@@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true); ...@@ -128,6 +128,8 @@ DEF(GammaRNG, 3, true, true);
DEF(BetaRNG, 3, true, true); DEF(BetaRNG, 3, true, true);
DEF(PoissonRNG, 2, true, true); DEF(PoissonRNG, 2, true, true);
DEF(PermutationRNG, 1, true, true); DEF(PermutationRNG, 1, true, true);
DEF(ShuffleRNGForward, 3, true, true);
DEF(ShuffleRNGBackward, 3, true, false);
DEF(ChecksumForward, 1, true, false); DEF(ChecksumForward, 1, true, false);
DEF(CheckHasInf, 2, true, true); DEF(CheckHasInf, 2, true, true);
DEF(LSQForward, 5, true, true); DEF(LSQForward, 5, true, true);
......
...@@ -15,6 +15,47 @@ ...@@ -15,6 +15,47 @@
namespace megdnn { namespace megdnn {
void ShuffleRNGForward::deduce_layout(const TensorLayout& src,
TensorLayout& dst,
TensorLayout& indices) {
dst = src;
indices = TensorLayout(TensorShape({src.shape[0]}), dtype::Int32());
}
/*!
 * Validate the arguments of a forward shuffle before execution.
 *
 * Checks that \p src and \p indices are contiguous, that \p dst and
 * \p indices match the layouts produced by deduce_layout(), that \p src and
 * \p dst share a dtype, that \p indices is Int32, and that the supplied
 * workspace is at least as large as get_workspace_in_bytes() reports.
 */
void ShuffleRNGForward::check_exec(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices,
size_t workspace_in_bytes) {
TensorLayout dst_expected, indices_expected;
megdnn_assert_contiguous(src);
// compare the caller-provided output layouts against the deduced ones
deduce_layout(src, dst_expected, indices_expected);
megdnn_assert_eq_layout(dst_expected, dst);
megdnn_assert_eq_layout(indices_expected, indices);
megdnn_assert_contiguous(indices);
megdnn_assert(src.dtype == dst.dtype);
megdnn_assert(indices.dtype == dtype::Int32());
auto required_workspace_in_bytes =
get_workspace_in_bytes(src, dst, indices);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void ShuffleRNGBackward::check_exec(const TensorLayout& diff,
const TensorLayout& indices,
const TensorLayout& grad,
size_t workspace_in_bytes) {
megdnn_assert(
diff.shape[0] == indices.shape[0] && diff.dtype == grad.dtype &&
indices.dtype == dtype::Int32{} && diff.is_contiguous() &&
indices.is_contiguous() && grad.is_contiguous(),
"invalid layouts: diff=%s indices=%s grad=%s",
diff.to_string().c_str(), indices.to_string().c_str(),
grad.to_string().c_str());
auto required_workspace_in_bytes =
get_workspace_in_bytes(diff, indices, grad);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void PermutationRNG::check_exec( void PermutationRNG::check_exec(
const TensorLayout &dst, size_t workspace_in_bytes) { const TensorLayout &dst, size_t workspace_in_bytes) {
megdnn_assert((dst.dtype == dtype::Float32() || megdnn_assert((dst.dtype == dtype::Float32() ||
......
...@@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs, ...@@ -55,6 +55,42 @@ __global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs,
} }
} }
/*!
 * Gather kernel of the forward shuffle: each thread handles at most one
 * output element.
 *
 * \param step number of elements in one slice along the first axis
 * \param src_size total number of elements to copy
 * \param sptr source buffer
 * \param dptr destination buffer
 * \param iptr for every destination slice, the index of its source slice
 */
template <typename T>
__global__ void shuffle_fwd_kernel(uint32_t step, uint32_t src_size,
                                   const T* sptr, T* dptr, const int* iptr) {
    const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    // threads past the end of the tensor have nothing to do
    if (tid >= src_size)
        return;
    const uint32_t row = tid / step;  // destination slice
    const uint32_t col = tid % step;  // offset inside the slice
    dptr[tid] = sptr[iptr[row] * step + col];
}
/*!
 * Launch the forward shuffle: dptr slice i receives sptr slice iptr[i].
 *
 * \param sptr source buffer holding len * step elements
 * \param dptr destination buffer holding len * step elements
 * \param iptr permutation with len entries
 * \param len number of slices along the first axis
 * \param step number of elements per slice
 * \param stream CUDA stream the kernel is enqueued on
 */
template <typename T>
void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr,
                     size_t len, size_t step, cudaStream_t stream) {
    // An empty tensor would yield a grid of zero blocks, which is not a
    // valid launch configuration; there is nothing to copy anyway.
    if (len == 0 || step == 0)
        return;
    // NOTE(review): the kernel indexes with 32-bit integers, so tensors with
    // 2^32 or more elements are not supported by this launcher.
    uint32_t src_size = len * step;
    shuffle_fwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
            step, src_size, sptr, dptr, iptr);
    after_kernel_launch();
}
/*!
 * Scatter kernel of the backward shuffle: each thread handles at most one
 * element of \p dptr and writes it to the permuted position in \p sptr.
 *
 * \param step number of elements in one slice along the first axis
 * \param src_size total number of elements to copy
 * \param sptr buffer being written (scatter target)
 * \param dptr buffer being read
 * \param iptr for every slice of \p dptr, the index of its target slice
 */
template <typename T>
__global__ void shuffle_bwd_kernel(uint32_t step, uint32_t src_size, T* sptr,
                                   T* dptr, const int* iptr) {
    const uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
    // threads past the end of the tensor have nothing to do
    if (tid >= src_size)
        return;
    const uint32_t row = tid / step;  // slice being read
    const uint32_t col = tid % step;  // offset inside the slice
    sptr[iptr[row] * step + col] = dptr[tid];
}
/*!
 * Launch the backward shuffle: sptr slice iptr[i] receives dptr slice i.
 *
 * \param dptr buffer being read, holding len * step elements
 * \param iptr permutation with len entries
 * \param sptr buffer being written, holding len * step elements
 * \param len number of slices along the first axis
 * \param step number of elements per slice
 * \param stream CUDA stream the kernel is enqueued on
 */
template <typename T>
void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,
                      size_t len, size_t step, cudaStream_t stream) {
    // An empty tensor would yield a grid of zero blocks, which is not a
    // valid launch configuration; there is nothing to copy anyway.
    if (len == 0 || step == 0)
        return;
    // NOTE(review): the kernel indexes with 32-bit integers, so tensors with
    // 2^32 or more elements are not supported by this launcher.
    uint32_t src_size = len * step;
    shuffle_bwd_kernel<<<DIVUP(src_size, 512), 512, 0, stream>>>(
            step, src_size, sptr, dptr, iptr);
    after_kernel_launch();
}
uint32_t get_permutation_bits(size_t N) { uint32_t get_permutation_bits(size_t N) {
double uniq_rand_num_prob = 0.9; double uniq_rand_num_prob = 0.9;
double thresh = std::log(uniq_rand_num_prob) * 12; double thresh = std::log(uniq_rand_num_prob) * 12;
...@@ -156,6 +192,14 @@ INST_PERMUTATION(dt_int16) ...@@ -156,6 +192,14 @@ INST_PERMUTATION(dt_int16)
INST_PERMUTATION(dt_float32) INST_PERMUTATION(dt_float32)
#undef INST_PERMUTATION #undef INST_PERMUTATION
// Explicitly instantiate shuffle_forward / shuffle_backward for every element
// type listed by ARGSORT_FOREACH_CTYPE. The shuffle_backward line omits <T>
// and lets the template argument be deduced from the parameter list.
#define INST_SHUFFLE(T) \
template void shuffle_forward<T>(T* sptr, T* dptr, dt_int32* iptr,\
size_t len, size_t step, cudaStream_t stream);\
template void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,\
size_t len, size_t step, cudaStream_t stream);
ARGSORT_FOREACH_CTYPE(INST_SHUFFLE)
#undef INST_SHUFFLE
} // namespace random } // namespace random
#define INST(_dtype) \ #define INST(_dtype) \
......
...@@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed ...@@ -253,6 +253,17 @@ void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed
size_t get_permutation_workspace_in_bytes(size_t N); size_t get_permutation_workspace_in_bytes(size_t N);
//! copy slice iptr[i] of sptr into slice i of dptr, for len slices of step
//! elements each; the work is enqueued on \p stream
template<typename T>
void shuffle_forward(T* sptr, T* dptr, dt_int32* iptr,
size_t len, size_t step, cudaStream_t stream);
//! copy slice i of dptr into slice iptr[i] of sptr, for len slices of step
//! elements each; the work is enqueued on \p stream
template<typename T>
void shuffle_backward(T* dptr, dt_int32* iptr, T* sptr,
size_t len, size_t step, cudaStream_t stream);
//! element types handled by the shuffle code: float, int32_t and, via
//! DNN_INC_FLOAT16, dt_float16
#define ARGSORT_FOREACH_CTYPE(cb) \
cb(float) cb(int32_t) DNN_INC_FLOAT16(cb(dt_float16))
} // namespace random } // namespace random
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
...@@ -9,11 +9,11 @@ ...@@ -9,11 +9,11 @@
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/ */
#include "./opr_impl.h"
#include "./kernel.cuh"
#include "src/common/utils.h" #include "src/common/utils.h"
#include "src/cuda/handle.h" #include "src/cuda/handle.h"
#include "src/cuda/utils.h" #include "src/cuda/utils.h"
#include "./opr_impl.h"
#include "./kernel.cuh"
using namespace megdnn; using namespace megdnn;
using namespace cuda; using namespace cuda;
...@@ -261,5 +261,76 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){ ...@@ -261,5 +261,76 @@ size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){
return random::get_permutation_workspace_in_bytes(size); return random::get_permutation_workspace_in_bytes(size);
} }
// Construct the CUDA forward shuffle operator: seed and offset start at zero
// and the stream is taken from the handle via cuda_stream().
ShuffleRNGForwardImpl::ShuffleRNGForwardImpl(Handle* handle)
: ShuffleRNGForward(handle),
m_seed(0),
m_offset(0),
m_stream(cuda_stream(handle)) {}
/*!
 * CUDA forward shuffle: write a random permutation into \p indices, then
 * gather the slices of \p src along the first axis into \p dst accordingly.
 *
 * \param src tensor to shuffle
 * \param dst shuffled result, same layout as \p src
 * \param indices receives the permutation (Int32, one entry per slice)
 * \param workspace scratch memory used while generating the permutation
 */
void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
                                 _megdnn_tensor_out indices,
                                 _megdnn_workspace workspace) {
    check_exec(src.layout, dst.layout, indices.layout, workspace.size);
    ensure_seed(m_param.seed);
    auto wk = workspace.ptr<void>();
    const auto len = indices.layout[0];
    random::permutation_forward<dt_int32>(indices.ptr<dt_int32>(), wk, len,
                                          m_seed, m_offset, m_stream);
    // A slice along axis 0 holds the PRODUCT of the trailing dimensions.
    // Summing them only coincides with the product for ndim <= 2.
    size_t step = 1;
    for (size_t i = 1; i < src.layout.ndim; ++i) {
        step *= src.layout[i];
    }
    // Nothing to copy for empty tensors; skipping also avoids a kernel
    // launch with a zero-sized grid.
    if (len > 0 && step > 0) {
        switch (src.layout.dtype.enumv()) {
#define cb(DType)                                                            \
    case DTypeTrait<DType>::enumv:                                           \
        random::shuffle_forward<DTypeTrait<DType>::ctype>(                   \
                src.ptr<DTypeTrait<DType>::ctype>(),                         \
                dst.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
                len, step, m_stream);                                        \
        break;
            ARGSORT_FOREACH_CTYPE(cb)
#undef cb
            default:
                megdnn_throw("bad dtype");
        }
    }
    // advance the generator offset so the next call draws a new permutation
    m_offset += 8;
}
/*!
 * Workspace needed by the forward shuffle; it depends only on the number of
 * permutation entries, i.e. the element count of \p indices.
 */
size_t ShuffleRNGForwardImpl::get_workspace_in_bytes(
        const TensorLayout&, const TensorLayout&,
        const TensorLayout& indices) {
    return random::get_permutation_workspace_in_bytes(
            indices.total_nr_elems());
}
// Construct the CUDA backward shuffle operator; the stream is taken from the
// handle via cuda_stream().
ShuffleRNGBackwardImpl::ShuffleRNGBackwardImpl(Handle* handle)
: ShuffleRNGBackward(handle), m_stream(cuda_stream(handle)) {}
/*!
 * CUDA backward shuffle: scatter slice i of \p diff into slice indices[i] of
 * \p grad, undoing the permutation applied by the forward operator.
 *
 * \param diff gradient w.r.t. the shuffled output
 * \param indices permutation produced by the forward operator
 * \param grad receives the gradient w.r.t. the forward input
 * \param workspace unused scratch memory (size is still validated)
 */
void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
                                  _megdnn_tensor_in indices,
                                  _megdnn_tensor_out grad,
                                  _megdnn_workspace workspace) {
    // validate layouts before any device pointer is touched
    check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
    const auto len = indices.layout[0];
    // A slice along axis 0 holds the PRODUCT of the trailing dimensions.
    // Summing them only coincides with the product for ndim <= 2.
    size_t step = 1;
    for (size_t i = 1; i < diff.layout.ndim; ++i) {
        step *= diff.layout[i];
    }
    // Nothing to copy for empty tensors; skipping also avoids a kernel
    // launch with a zero-sized grid.
    if (len == 0 || step == 0)
        return;
    switch (diff.layout.dtype.enumv()) {
#define cb(DType)                                                          \
    case DTypeTrait<DType>::enumv:                                         \
        random::shuffle_backward<DTypeTrait<DType>::ctype>(                \
                diff.ptr<DTypeTrait<DType>::ctype>(), indices.ptr<dt_int32>(), \
                grad.ptr<DTypeTrait<DType>::ctype>(), len, step, m_stream); \
        break;
        ARGSORT_FOREACH_CTYPE(cb)
#undef cb
        default:
            megdnn_throw("bad dtype");
    }
}
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
...@@ -152,6 +153,45 @@ public: ...@@ -152,6 +153,45 @@ public:
} }
}; };
class ShuffleRNGForwardImpl : public ShuffleRNGForward {
uint64_t m_seed, m_offset;
cudaStream_t m_stream;
public:
using ShuffleRNGForward::ShuffleRNGForward;
ShuffleRNGForwardImpl(Handle* handle);
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices, _megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(const TensorLayout& src,
const TensorLayout& dst,
const TensorLayout& indices) override;
void seed(uint64_t seed) { m_seed = seed; }
void ensure_seed(uint64_t seed) {
if (m_seed != seed) {
this->seed(seed);
}
}
};
//! CUDA implementation of the backward shuffle operator; it keeps the stream
//! its kernels are enqueued on and needs no workspace.
class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
cudaStream_t m_stream;
public:
using ShuffleRNGBackward::ShuffleRNGBackward;
ShuffleRNGBackwardImpl(Handle* handle);
void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
_megdnn_tensor_out grad, _megdnn_workspace workspace) override;
//! the backward pass requires no scratch memory
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
} // namespace cuda } // namespace cuda
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -6,12 +6,13 @@ ...@@ -6,12 +6,13 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "src/naive/handle.h"
#include "src/common/utils.h"
#include "./opr_impl.h" #include "./opr_impl.h"
#include "src/common/utils.h"
#include "src/naive/handle.h"
#include <cmath> #include <cmath>
...@@ -229,7 +230,29 @@ namespace { ...@@ -229,7 +230,29 @@ namespace {
} }
} }
} // anonymous namespace template <typename T>
// Gather: slice `row` of dptr receives slice iptr[row] of sptr; every slice
// holds `step` contiguous elements and there are `len` slices.
void shuffle_fwd(const T* __restrict sptr, T* __restrict dptr,
                 const dt_int32* iptr, const size_t len,
                 const size_t step) MEGDNN_NOEXCEPT {
    T* out = dptr;
    for (size_t row = 0; row < len; ++row) {
        // start of the source slice selected by the permutation
        const T* in = sptr + iptr[row] * step;
        for (size_t col = 0; col < step; ++col) {
            *out++ = in[col];
        }
    }
}
// Scatter: slice iptr[row] of sptr receives slice `row` of dptr; every slice
// holds `step` contiguous elements and there are `len` slices.
template <typename T>
void shuffle_bwd(T* __restrict sptr, const T* __restrict dptr,
                 const dt_int32* iptr, const size_t len,
                 const size_t step) MEGDNN_NOEXCEPT {
    const T* in = dptr;
    for (size_t row = 0; row < len; ++row) {
        // start of the destination slice selected by the permutation
        T* out = sptr + iptr[row] * step;
        for (size_t col = 0; col < step; ++col) {
            out[col] = *in++;
        }
    }
}
} // anonymous namespace
uint64_t Splitmix64::operator() () { uint64_t Splitmix64::operator() () {
uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15)); uint64_t z = (m_s += UINT64_C(0x9E3779B97F4A7C15));
...@@ -394,5 +417,54 @@ void PermutationRNGImpl::exec( ...@@ -394,5 +417,54 @@ void PermutationRNGImpl::exec(
} }
} }
// vim: syntax=cpp.doxygen void ShuffleRNGForwardImpl::exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices,
_megdnn_workspace workspace) {
check_exec(src.layout, dst.layout, indices.layout, workspace.size);
const auto len = indices.layout[0];
auto iptr = indices.ptr<dt_int32>();
auto prng = &m_rng.ensure_seed(m_param.seed);
fill_permutation<dt_int32>(prng, iptr, len);
auto step = 0;
for (size_t i = 1; i < src.layout.ndim; ++i) {
step += src.layout[i];
}
if (step <= 0)
step = 1;
#define cb(DType) \
if (src.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
MEGDNN_DISPATCH_CPU_KERN_OPR( \
shuffle_fwd<T>(src.ptr<T>(), dst.ptr<T>(), iptr, len, step)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
}
/*!
 * Naive backward shuffle: scatter slice i of \p diff into slice indices[i]
 * of \p grad, undoing the permutation applied by the forward operator.
 *
 * \param diff gradient w.r.t. the shuffled output
 * \param indices permutation produced by the forward operator
 * \param grad receives the gradient w.r.t. the forward input
 * \param workspace unused scratch memory (size is still validated)
 */
void ShuffleRNGBackwardImpl::exec(_megdnn_tensor_in diff,
                                  _megdnn_tensor_in indices,
                                  _megdnn_tensor_out grad,
                                  _megdnn_workspace workspace) {
    check_exec(diff.layout, indices.layout, grad.layout, workspace.size);
    const auto len = indices.layout[0];
    auto iptr = indices.ptr<dt_int32>();
    // A slice along axis 0 holds the PRODUCT of the trailing dimensions.
    // Summing them only coincides with the product for ndim <= 2.
    size_t step = 1;
    for (size_t i = 1; i < diff.layout.ndim; ++i) {
        step *= diff.layout[i];
    }
#define cb(DType)                                                \
    if (diff.layout.dtype == DType()) {                          \
        using T = typename DTypeTrait<DType>::ctype;             \
        MEGDNN_DISPATCH_CPU_KERN_OPR(shuffle_bwd<T>(             \
                grad.ptr<T>(), diff.ptr<T>(), iptr, len, step)); \
        return;                                                  \
    }
    MEGDNN_FOREACH_COMPUTING_DTYPE(cb)
#undef cb
    // no dtype matched: fail loudly instead of leaving grad unwritten
    megdnn_throw("bad dtype");
}
// vim: syntax=cpp.doxygen
...@@ -128,6 +128,35 @@ public: ...@@ -128,6 +128,35 @@ public:
size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; }
}; };
//! Naive (CPU reference) implementation of the forward shuffle operator; it
//! owns the generator used to draw the permutation and needs no workspace.
class ShuffleRNGForwardImpl : public ShuffleRNGForward {
// generator state; exec() reseeds it through ensure_seed()
Xoroshiro128plus m_rng;
public:
using ShuffleRNGForward::ShuffleRNGForward;
void exec(_megdnn_tensor_in src, _megdnn_tensor_out dst,
_megdnn_tensor_out indices, _megdnn_workspace workspace) override;
//! the naive forward pass requires no scratch memory
size_t get_workspace_in_bytes(const TensorLayout&,
const TensorLayout&,
const TensorLayout&) override {
return 0;
}
};
/*!
 * Naive (CPU reference) implementation of the backward shuffle operator.
 *
 * The backward pass only replays the permutation recorded by the forward
 * pass, so it draws no random numbers and carries no generator state.
 */
class ShuffleRNGBackwardImpl : public ShuffleRNGBackward {
public:
    using ShuffleRNGBackward::ShuffleRNGBackward;
    void exec(_megdnn_tensor_in diff, _megdnn_tensor_in indices,
              _megdnn_tensor_out grad, _megdnn_workspace workspace) override;
    //! the naive backward pass requires no scratch memory
    size_t get_workspace_in_bytes(const TensorLayout&,
                                  const TensorLayout&,
                                  const TensorLayout&) override {
        return 0;
    }
};
} // namespace naive } // namespace naive
} // namespace megdnn } // namespace megdnn
// vim: syntax=cpp.doxygen // vim: syntax=cpp.doxygen
...@@ -143,6 +143,60 @@ void run_permutation(Handle* handle) { ...@@ -143,6 +143,60 @@ void run_permutation(Handle* handle) {
} }
} }
// Exercise ShuffleRNGForward (and, when bwd_flag is set, ShuffleRNGBackward)
// on the given handle for element dtype T.
// NOTE(review): only a 1-D and a 2-D shape are run; no shape with three or
// more dimensions is covered.
template <typename T>
void run_shuffle(Handle* handle, bool bwd_flag) {
using ctype = typename DTypeTrait<T>::ctype;
auto run = [&](TensorShape shape) {
auto opr = handle->create_operator<ShuffleRNGForward>();
TensorLayout srclay{shape, T()};
TensorLayout dstlay{shape, T()};
TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
Tensor<dt_byte> workspace(
handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay,
indexlay)},
dtype::Byte()});
SyncedTensor<ctype> src(handle, srclay);
SyncedTensor<ctype> dst(handle, dstlay);
SyncedTensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
// fill the source with 0, 1, 2, ... so every element is distinguishable
auto sptr = src.ptr_mutable_host();
size_t size = src.layout().total_nr_elems();
for (size_t j = 0; j < size; ++j) {
sptr[j] = j;
}
opr->exec(src.tensornd_dev(), dst.tensornd_dev(), index.tensornd_dev(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
auto dptr = dst.ptr_mutable_host();
auto iptr = index.ptr_mutable_host();
size_t len = index.layout().total_nr_elems();
size_t step = size / len;
// forward contract: slice i of dst equals slice index[i] of src
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
}
}
if (bwd_flag) {
// clear src so the check below only passes if backward rewrites it
for (size_t j = 0; j < size; ++j) {
sptr[j] = 0;
}
auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
oprbwd->exec(
dst.tensornd_dev(), index.tensornd_dev(),
src.tensornd_dev(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
auto sptr_bwd = src.ptr_mutable_host();
// backward contract: slice index[i] of grad equals slice i of diff
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr_bwd[iptr[i] * step + j]);
}
}
}
};
run({10});
run({6, 3});
}
} // anonymous namespace } // anonymous namespace
TEST_F(CUDA, UNIFORM_RNG_F32) { TEST_F(CUDA, UNIFORM_RNG_F32) {
...@@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) { ...@@ -215,6 +269,30 @@ TEST_F(CUDA, PERMUTATION_RNG_INT16) {
run_permutation<dtype::Int16>(handle_cuda()); run_permutation<dtype::Int16>(handle_cuda());
} }
// CUDA shuffle tests: forward only (flag false) and forward followed by
// backward (flag true), for Float32, Int32 and Float16.
TEST_F(CUDA, SHUFFLE_RNG_F32) {
run_shuffle<dtype::Float32>(handle_cuda(), false);
}
TEST_F(CUDA, SHUFFLE_RNG_INT32) {
run_shuffle<dtype::Int32>(handle_cuda(), false);
}
TEST_F(CUDA, SHUFFLE_RNG_F16) {
run_shuffle<dtype::Float16>(handle_cuda(), false);
}
TEST_F(CUDA, SHUFFLE_RNG_BWD_F32) {
run_shuffle<dtype::Float32>(handle_cuda(), true);
}
TEST_F(CUDA, SHUFFLE_RNG_BWD_INT32) {
run_shuffle<dtype::Int32>(handle_cuda(), true);
}
TEST_F(CUDA, SHUFFLE_RNG_BWD_F16) {
run_shuffle<dtype::Float16>(handle_cuda(), true);
}
} // namespace test } // namespace test
} // namespace megdnn } // namespace megdnn
......
...@@ -6,12 +6,13 @@ ...@@ -6,12 +6,13 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megdnn.h"
#include "test/naive/fixture.h"
#include "test/naive/rng.h" #include "test/naive/rng.h"
#include "megdnn.h"
#include "test/common/tensor.h" #include "test/common/tensor.h"
#include "test/naive/fixture.h"
namespace megdnn { namespace megdnn {
...@@ -181,7 +182,59 @@ namespace { ...@@ -181,7 +182,59 @@ namespace {
ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8);
} }
} }
}
// Exercise ShuffleRNGForward (and, when bwd_flag is set, ShuffleRNGBackward)
// on the given handle for element dtype T.
// NOTE(review): only a 1-D and a 2-D shape are run; no shape with three or
// more dimensions is covered.
template <typename T>
void run_shuffle(Handle* handle, bool bwd_flag) {
using ctype = typename DTypeTrait<T>::ctype;
auto run = [&](TensorShape shape) {
auto opr = handle->create_operator<ShuffleRNGForward>();
TensorLayout srclay{shape, T()};
TensorLayout dstlay{shape, T()};
TensorLayout indexlay{TensorShape{shape[0]}, dtype::Int32()};
Tensor<dt_byte> workspace(
handle, {TensorShape{opr->get_workspace_in_bytes(srclay, dstlay,
indexlay)},
dtype::Byte()});
Tensor<ctype> src(handle, srclay);
Tensor<ctype> dst(handle, dstlay);
Tensor<DTypeTrait<dt_int32>::ctype> index(handle, indexlay);
// fill the source with 0, 1, 2, ... so every element is distinguishable
auto sptr = src.ptr();
size_t size = src.layout().total_nr_elems();
for (size_t j = 0; j < size; ++j) {
sptr[j] = j;
}
opr->exec(src.tensornd(), dst.tensornd(), index.tensornd(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
auto dptr = dst.ptr();
auto iptr = index.ptr();
size_t len = index.layout().total_nr_elems();
size_t step = size / len;
// forward contract: slice i of dst equals slice index[i] of src
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
}
}
if (bwd_flag) {
// clear src so the check below only passes if backward rewrites it
for (size_t j = 0; j < size; ++j) {
sptr[j] = 0;
}
auto oprbwd = handle->create_operator<ShuffleRNGBackward>();
oprbwd->exec(
dst.tensornd(), index.tensornd(), src.tensornd(),
{workspace.ptr(), workspace.layout().total_nr_elems()});
// backward contract: slice index[i] of grad equals slice i of diff
for (size_t i = 0; i < len; ++i) {
for (size_t j = 0; j < step; ++j) {
ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
}
}
}
};
run({10});
run({6, 3});
}
} // namespace
TEST_F(NAIVE, UNIFORM_RNG_F32) { TEST_F(NAIVE, UNIFORM_RNG_F32) {
run_uniform<dtype::Float32>(handle()); run_uniform<dtype::Float32>(handle());
...@@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) { ...@@ -235,10 +288,31 @@ TEST_F(NAIVE, PERMUTATION_RNG_INT16) {
run_permutation<dtype::Int16>(handle()); run_permutation<dtype::Int16>(handle());
} }
} // namespace test TEST_F(NAIVE, SHUFFLE_RNG_FWD_F32) {
} // namespace megdnn run_shuffle<dtype::Float32>(handle(), false);
}
// vim: syntax=cpp.doxygen TEST_F(NAIVE, SHUFFLE_RNG_FWD_INT32) {
run_shuffle<dtype::Int32>(handle(), false);
}
// Naive shuffle tests: forward only (flag false) and forward followed by
// backward (flag true).
TEST_F(NAIVE, SHUFFLE_RNG_FWD_F16) {
run_shuffle<dtype::Float16>(handle(), false);
}
TEST_F(NAIVE, SHUFFLE_RNG_BWD_F32) {
run_shuffle<dtype::Float32>(handle(), true);
}
TEST_F(NAIVE, SHUFFLE_RNG_BWD_INT32) {
run_shuffle<dtype::Int32>(handle(), true);
}
TEST_F(NAIVE, SHUFFLE_RNG_BWD_F16) {
run_shuffle<dtype::Float16>(handle(), true);
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
# Unless required by applicable law or agreed to in writing, # Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an # software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, shuffle, uniform
__all__ = [ __all__ = [
"RNG", "RNG",
...@@ -17,6 +17,7 @@ __all__ = [ ...@@ -17,6 +17,7 @@ __all__ = [
"poisson", "poisson",
"seed", "seed",
"uniform", "uniform",
"shuffle",
] ]
# pylint: disable=undefined-variable # pylint: disable=undefined-variable
del rng # type: ignore[name-defined] del rng # type: ignore[name-defined]
...@@ -27,6 +27,7 @@ from ..core.ops.builtin import ( ...@@ -27,6 +27,7 @@ from ..core.ops.builtin import (
GaussianRNG, GaussianRNG,
PermutationRNG, PermutationRNG,
PoissonRNG, PoissonRNG,
ShuffleRNG,
UniformRNG, UniformRNG,
) )
from ..core.tensor import utils from ..core.tensor import utils
...@@ -41,6 +42,7 @@ __all__ = [ ...@@ -41,6 +42,7 @@ __all__ = [
"beta", "beta",
"poisson", "poisson",
"permutation", "permutation",
"shuffle",
] ]
_rng = None _rng = None
...@@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten ...@@ -219,6 +221,13 @@ def _permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Ten
return output return output
def _shuffle(inp: Tensor, seed: int, handle: int) -> None:
    """Shuffle ``inp`` in place along its first axis.

    Applies a :class:`ShuffleRNG` op to ``inp`` and rebinds ``inp`` to the
    shuffled result. Nothing is returned; the second output of the op is
    discarded.

    Args:
        inp: tensor to shuffle; must contain at least one element.
        seed: seed forwarded to the op.
        handle: RNG handle forwarded to the op.

    Raises:
        AssertionError: if ``inp`` is empty.
    """
    assert inp.size > 0, "size needs to be greater than 0"
    op = ShuffleRNG(seed=seed, handle=handle)
    output, _ = apply(op, inp)
    # the caller's tensor object now refers to the shuffled data
    inp._reset(output)
class RNG: class RNG:
r""":class:`RNG` exposes a number of methods for generating random numbers. r""":class:`RNG` exposes a number of methods for generating random numbers.
...@@ -581,6 +590,45 @@ class RNG: ...@@ -581,6 +590,45 @@ class RNG:
n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype
) )
def shuffle(self, inp: Tensor) -> None:
    r"""Modify a tensor in-place by shuffling its contents.

    Only the first axis of a multi-dimensional tensor is shuffled: the order
    of the sub-tensors changes but their contents remain the same.

    Args:
        inp: input tensor; it is modified in place and nothing is returned.

    Examples:

        .. testcode::

            import numpy as np
            import megengine as mge
            import megengine.random as rand

            x = mge.tensor(np.arange(10))
            rand.shuffle(x)
            print(x.numpy())
            y = mge.tensor(np.arange(18)).reshape(6,3)
            rand.shuffle(y)
            print(y.numpy())

        Outputs:

        .. testoutput::
            :options: +SKIP

            [7 9 3 0 8 2 4 5 6 1]
            [[12 13 14]
             [ 3  4  5]
             [15 16 17]
             [ 0  1  2]
             [ 9 10 11]
             [ 6  7  8]]
    """
    # self._seed may be a fixed value or a callable producing the seed
    _seed = self._seed() if callable(self._seed) else self._seed
    _shuffle(inp=inp, seed=_seed, handle=self._handle)
def __del__(self): def __del__(self):
if self._handle != 0: if self._handle != 0:
_delete_rng_handle(self._handle) _delete_rng_handle(self._handle)
...@@ -599,6 +647,7 @@ gamma = _default_handle.gamma ...@@ -599,6 +647,7 @@ gamma = _default_handle.gamma
beta = _default_handle.beta beta = _default_handle.beta
poisson = _default_handle.poisson poisson = _default_handle.poisson
permutation = _default_handle.permutation permutation = _default_handle.permutation
shuffle = _default_handle.shuffle
def _random_seed_generator(): def _random_seed_generator():
......
...@@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import ( ...@@ -18,6 +18,7 @@ from megengine.core._imperative_rt.ops import (
get_global_rng_seed, get_global_rng_seed,
new_rng_handle, new_rng_handle,
) )
from megengine.core.autodiff.grad import Grad
from megengine.core.ops.builtin import ( from megengine.core.ops.builtin import (
BetaRNG, BetaRNG,
GammaRNG, GammaRNG,
...@@ -397,6 +398,45 @@ def test_PermutationRNG(): ...@@ -397,6 +398,45 @@ def test_PermutationRNG():
assert sum_result(out, np.sort) == 1000 assert sum_result(out, np.sort) == 1000
# Checks random.shuffle / RNG.shuffle: it runs under Grad, gives identical
# results for equal seeds on two devices, differs for another seed, and
# preserves the input shape. Skipped unless more than one "xpu" device exists.
@pytest.mark.skipif(
    get_device_count("xpu") <= 1, reason="xpu counts need > 1",
)
def test_ShuffleRNG():
    g = []

    def cb(grad):
        g.append(grad)

    n, m = 6, 3
    arr = np.arange(n * m)
    out0 = Tensor(arr, dtype="float32")
    # NOTE(review): the gradient collected in `g` is never asserted on; this
    # part only verifies that backward runs without raising.
    grad = Grad().wrt(out0, callback=cb)
    random.shuffle(out0)
    grad(out0, F.ones_like(out0))
    m1 = RNG(seed=111, device="xpu0")
    m2 = RNG(seed=111, device="xpu1")
    m3 = RNG(seed=222, device="xpu0")
    out1 = Tensor(arr, dtype="float32", device="xpu0")
    out2 = Tensor(arr, dtype="float32", device="xpu1")
    out3 = Tensor(arr, dtype="float32", device="xpu0")
    m1.shuffle(out1)
    m2.shuffle(out2)
    m3.shuffle(out3)

    # same seed on different devices -> same permutation
    np.testing.assert_equal(out1.numpy(), out2.numpy())
    assert out1.device == "xpu0" and out2.device == "xpu1"
    # different seed -> different permutation
    assert not (out1.numpy() == out3.numpy()).all()

    out = Tensor(arr, dtype="float32").reshape(n, m)
    m1.shuffle(out)
    out_shp = out.shape
    # the shape may be a plain tuple or a tensor-like object with .numpy()
    if isinstance(out_shp, tuple):
        assert out_shp == (n, m)
    else:
        assert all(out.shape.numpy() == np.array([n, m]))
def test_seed(): def test_seed():
set_global_seed(10) set_global_seed(10)
out1 = uniform(size=[10, 10]) out1 = uniform(size=[10, 10])
......
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megbrain/imperative/ops/rng.h" #include "megbrain/imperative/ops/rng.h"
...@@ -14,8 +15,8 @@ ...@@ -14,8 +15,8 @@
#include "megbrain/graph/helper.h" #include "megbrain/graph/helper.h"
#include "megbrain/opr/rand.h" #include "megbrain/opr/rand.h"
#include "../op_trait.h"
#include "../dnn_op_helper.h" #include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative::rng { namespace mgb::imperative::rng {
...@@ -259,13 +260,27 @@ struct OpMeth<BetaRNG> { ...@@ -259,13 +260,27 @@ struct OpMeth<BetaRNG> {
} }
}; };
//! Glue between the imperative ShuffleRNG op, its megdnn operator and its
//! graph operator node.
template <>
struct OpMeth<ShuffleRNG> {
    using DnnOp = megdnn::ShuffleRNG;
    using Param = DnnOp::Param;
    using OpNode = mgb::opr::ShuffleRNG;
    //! Build the megdnn param from the op; the seed stored in the op must
    //! match the seed registered for its RNG handle.
    static Param make_param(const ShuffleRNG& rng) {
        auto handle_seed = RNGDnnOpManager::get_seed(rng.handle);
        // arguments follow the order of the labels in the message:
        // first the op's seed, then the handle's seed
        mgb_assert(handle_seed == rng.seed,
                   "inconsistent rng seed: rng op: %lu handle: %lu",
                   rng.seed, handle_seed);
        return {handle_seed};
    }
};
template <bool> template <bool>
struct _InferLayout; struct _InferLayout;
template <int nr_in> template <int nr_in>
struct _RNGOprMaker; struct _RNGOprMaker;
template <int nr_in> template <int nr_in, int nr_out>
struct _RNGOprInvoker; struct _RNGOprInvoker;
template<> template<>
...@@ -316,50 +331,63 @@ struct _InferLayout<false> ...@@ -316,50 +331,63 @@ struct _InferLayout<false>
return inp.layout; return inp.layout;
} }
}; };
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \
template<> \
struct _RNGOprInvoker<DNN_NR_INPUTS> { \
template<typename Opr> \
static void exec(Opr *dnn_op, const SmallVector<TensorPtr>& inputs,const TensorPtr& dest){ \
size_t wk_size = 0; \
wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \
auto workspace = Blob::make(dest->comp_node(), wk_size); \
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
dest->dev_tensor().as_megdnn(), dnn_wk); \
} \
};
// Define the _RNGOprInvoker specialization for a given input/output count.
// Its exec() queries the workspace size from the layouts, allocates a blob of
// that size on the comp node of outputs[0], and runs the megdnn operator.
// _FOR_EACH_IN / _FOR_EACH_OUT must be defined at the expansion site; they
// expand to the comma-separated inputs[i] / outputs[i] with the given suffix.
#define _INST_RNG_INVOLKER(DNN_NR_INPUTS, DNN_NR_OUTPUTS) \
template <> \
struct _RNGOprInvoker<DNN_NR_INPUTS, DNN_NR_OUTPUTS> { \
template <typename Opr> \
static void exec(Opr* dnn_op, const SmallVector<TensorPtr>& inputs, \
const SmallVector<TensorPtr>& outputs) { \
size_t wk_size = 0; \
wk_size = dnn_op->get_workspace_in_bytes( \
_FOR_EACH_IN(->layout()) _FOR_EACH_OUT(->layout())); \
auto workspace = Blob::make(outputs[0]->comp_node(), wk_size); \
megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \
dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \
_FOR_EACH_OUT(->dev_tensor().as_megdnn()), \
dnn_wk); \
} \
};
#define _INST_RNG_MAKER(MGB_NR_INPUTS) \ #define _INST_RNG_MAKER(MGB_NR_INPUTS) \
template<> \ template <> \
struct _RNGOprMaker<MGB_NR_INPUTS> { \ struct _RNGOprMaker<MGB_NR_INPUTS> { \
template<typename Op> \ template <typename Op> \
static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \ static auto make(const VarNodeArray& inputs, const Op& rng) { \
auto param = OpMeth<Op>::make_param(rng); \ auto param = OpMeth<Op>::make_param(rng); \
OperatorNodeConfig config; \ OperatorNodeConfig config; \
if (rng.handle) { \ if (rng.handle) { \
config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \ config = {rng.make_name(), \
} else { \ RNGDnnOpManager::get_comp_node(rng.handle)}; \
config = {rng.make_name()}; \ } else { \
} \ config = {rng.make_name()}; \
return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \ } \
} \ return OpMeth<Op>::OpNode::make(_FOR_EACH_IN() param, config); \
}; } \
};
#define _FOR_EACH_IN(subfix) #define _FOR_EACH_IN(subfix)
_INST_RNG_INVOLKER(0) #define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(0, 1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN #undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) inputs[0] subfix, #define _FOR_EACH_IN(subfix) inputs[0] subfix,
_INST_RNG_INVOLKER(1) #define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(1, 1)
#undef _FOR_EACH_OUT
#define _FOR_EACH_OUT(subfix) outputs[0] subfix, outputs[1] subfix
_INST_RNG_INVOLKER(1, 2)
_INST_RNG_MAKER(1) _INST_RNG_MAKER(1)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN #undef _FOR_EACH_IN
#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix, #define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix,
_INST_RNG_INVOLKER(2) #define _FOR_EACH_OUT(subfix) outputs[0] subfix
_INST_RNG_INVOLKER(2, 1)
_INST_RNG_MAKER(2) _INST_RNG_MAKER(2)
#undef _FOR_EACH_OUT
#undef _FOR_EACH_IN #undef _FOR_EACH_IN
#undef _INST_RNG_INVOLKER #undef _INST_RNG_INVOLKER
...@@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs, ...@@ -392,7 +420,9 @@ void exec(const OpDef& op, const SmallVector<TensorPtr>& inputs,
handle_seed, dnn_op->param().seed); handle_seed, dnn_op->param().seed);
} }
dnn_op->param() = OpMeth<Op>::make_param(rng); dnn_op->param() = OpMeth<Op>::make_param(rng);
_RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest); _RNGOprInvoker<OpMeth<Op>::DnnOp::NR_INPUTS,
OpMeth<Op>::DnnOp::NR_OUTPUTS>::exec(dnn_op, inputs,
outputs);
} }
template <typename Op> template <typename Op>
...@@ -420,24 +450,45 @@ SmallVector<LogicalTensorDesc> infer_output_attrs( ...@@ -420,24 +450,45 @@ SmallVector<LogicalTensorDesc> infer_output_attrs(
return {dest}; return {dest};
} }
template <typename Op> template <>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>> infer_output_mem_desc( SmallVector<LogicalTensorDesc> infer_output_attrs<ShuffleRNG>(
const OpDef& def, const OpDef& op, const SmallVector<TensorPtr>& inputs) {
const SmallVector<TensorPtr>& inputs_tensors, SmallVector<LogicalTensorDesc> dests(2);
const SmallVector<MemoryDesc>& inputs_mems) { auto&& rng = op.cast_final_safe<ShuffleRNG>();
auto &&dest = infer_output_attrs<Op>(def, inputs_tensors); auto handle = rng.handle;
SmallVector<MemoryDesc> outputs = {{dest[0].layout, 0, dest[0].comp_node, StorageIdentifier::make(1)}}; if (handle) {
dests[0].comp_node = RNGDnnOpManager::get_comp_node(handle);
return {outputs, {}}; dests[1].comp_node = RNGDnnOpManager::get_comp_node(handle);
} else {
dests[0].comp_node = inputs[0]->comp_node();
dests[1].comp_node = inputs[0]->comp_node();
}
dests[0].layout = TensorLayout(inputs[0]->layout());
dests[0].layout.dtype = inputs[0]->layout().dtype;
dests[1].layout =
TensorLayout(TensorShape({inputs[0]->layout()[0]}), dtype::Int32());
return dests;
} }
/*!
 * \brief describe the memory of every output of an RNG op
 *
 * One MemoryDesc is produced per entry returned by infer_output_attrs<Op>,
 * using that entry's layout and comp node, a zero offset, and storage
 * identifiers numbered from 1 in output order. The second tuple element is
 * always empty, and inputs_mems is not consulted.
 */
template <typename Op>
std::tuple<SmallVector<MemoryDesc>, SmallVector<MemoryDesc>>
infer_output_mem_desc(const OpDef& def,
                      const SmallVector<TensorPtr>& inputs_tensors,
                      const SmallVector<MemoryDesc>& inputs_mems) {
    auto&& attrs = infer_output_attrs<Op>(def, inputs_tensors);
    SmallVector<MemoryDesc> out_descs;
    size_t storage_id = 0;
    for (auto&& attr : attrs) {
        // pre-increment so that the first output gets identifier 1
        out_descs.push_back({attr.layout, 0, attr.comp_node,
                             StorageIdentifier::make(++storage_id)});
    }
    return {out_descs, {}};
}
template <typename Op> template <typename Op>
SmallVector<TensorPtr> apply_on_physical_tensor( SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs) { const OpDef& def, const SmallVector<TensorPtr>& inputs) {
SmallVector<TensorPtr> outputs; SmallVector<TensorPtr> outputs;
SmallVector<LogicalTensorDesc> desc; SmallVector<LogicalTensorDesc> desc = infer_output_attrs<Op>(def, inputs);
desc = infer_output_attrs<Op>(def, inputs);
for (auto&& i : desc) { for (auto&& i : desc) {
outputs.push_back(Tensor::make(i.layout, i.comp_node)); outputs.push_back(Tensor::make(i.layout, i.comp_node));
} }
...@@ -454,10 +505,8 @@ void execute( ...@@ -454,10 +505,8 @@ void execute(
exec<Op>(def, inputs, outputs, {}); exec<Op>(def, inputs, outputs, {});
} }
template<typename Op> template <typename Op, typename Output>
SymbolVar apply_on_var_node( Output apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
const OpDef& def,
const VarNodeArray& inputs) {
size_t nr_inp = inputs.size(); size_t nr_inp = inputs.size();
constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS; constexpr size_t dnn_nr_inp = OpMeth<Op>::DnnOp::NR_INPUTS;
auto&& rng = def.cast_final_safe<Op>(); auto&& rng = def.cast_final_safe<Op>();
...@@ -487,7 +536,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible( ...@@ -487,7 +536,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
return {{dest}, true}; return {{dest}, true};
} }
} // anonymous namespace template <>
std::tuple<SmallVector<LogicalTensorDesc>, bool>
infer_output_attrs_fallible<ShuffleRNG>(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
SmallVector<LogicalTensorDesc> dests(2);
dests[0].comp_node = inputs[0].comp_node;
dests[0].layout = TensorLayout(inputs[0].layout);
dests[0].layout.dtype = inputs[0].layout.dtype;
dests[1].comp_node = inputs[0].comp_node;
dests[1].layout = TensorLayout(TensorShape({inputs[0].layout.shape[0]}),
dtype::Int32());
return {dests, true};
}
} // anonymous namespace
Handle new_handle(CompNode comp_node, uint64_t seed) { Handle new_handle(CompNode comp_node, uint64_t seed) {
return RNGDnnOpManager::inst().new_handle(comp_node, seed); return RNGDnnOpManager::inst().new_handle(comp_node, seed);
...@@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){ ...@@ -509,23 +572,24 @@ CompNode get_rng_handle_compnode(Handle handle){
return RNGDnnOpManager::get_comp_node(handle); return RNGDnnOpManager::get_comp_node(handle);
} }
#define REG_RNG_OP(NAME)\ #define REG_RNG_OP(NAME, Output) \
namespace { \ namespace { \
OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \ OP_TRAIT_REG(NAME, NAME, OpMeth<NAME>::OpNode) \
.apply_on_var_node(apply_on_var_node<NAME>) \ .apply_on_var_node(apply_on_var_node<NAME, Output>) \
.apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \ .apply_on_physical_tensor(apply_on_physical_tensor<NAME>) \
.infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \ .infer_output_attrs_fallible(infer_output_attrs_fallible<NAME>) \
.infer_output_mem_desc(infer_output_mem_desc<NAME>) \ .infer_output_mem_desc(infer_output_mem_desc<NAME>) \
.execute(execute<NAME>) \ .execute(execute<NAME>) \
.fallback(); \ .fallback(); \
} \ }
REG_RNG_OP(UniformRNG) REG_RNG_OP(UniformRNG, SymbolVar)
REG_RNG_OP(GaussianRNG) REG_RNG_OP(GaussianRNG, SymbolVar)
REG_RNG_OP(GammaRNG) REG_RNG_OP(GammaRNG, SymbolVar)
REG_RNG_OP(PermutationRNG) REG_RNG_OP(PermutationRNG, SymbolVar)
REG_RNG_OP(PoissonRNG) REG_RNG_OP(PoissonRNG, SymbolVar)
REG_RNG_OP(BetaRNG) REG_RNG_OP(BetaRNG, SymbolVar)
REG_RNG_OP(ShuffleRNG, SymbolVarArray)
#undef REG_RNG_OP #undef REG_RNG_OP
} // namespace mgb::imperative::rng } // namespace mgb::imperative::rng
......
...@@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> { ...@@ -215,6 +215,19 @@ def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> {
let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}];
} }
// Hashable op definition for ShuffleRNG: carries ShuffleRNGParam plus one
// extra argument, `handle`.
def ShuffleRNG: MgbHashableOp<"ShuffleRNG", [ShuffleRNGParam]> {
let extraArguments = (ins
MgbSizeTAddr:$handle
);
// The hash combines the op's dynamic type info with the handle.
let hashFunction = [{
return mgb::hash_pair_combine(
mgb::hash($_self.dyn_typeinfo()),
mgb::hash($_self.handle)
);
}];
// Two ops compare equal when their handles match; no other field takes part
// in the comparison.
let cmpFunction = [{return $0.handle == $1.handle;}];
}
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> {
let extraArguments = (ins let extraArguments = (ins
MgbCompNodeAttr:$comp_node MgbCompNodeAttr:$comp_node
......
...@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>; ...@@ -192,6 +192,8 @@ template class RNGOprBase<::megdnn::GammaRNG>;
template class RNGOprBase<::megdnn::PermutationRNG>; template class RNGOprBase<::megdnn::PermutationRNG>;
template class RNGOprBase<::megdnn::BetaRNG>; template class RNGOprBase<::megdnn::BetaRNG>;
template class RNGOprBase<::megdnn::PoissonRNG>; template class RNGOprBase<::megdnn::PoissonRNG>;
template class RNGOprBase<::megdnn::ShuffleRNGForward>;
template class RNGOprBase<::megdnn::ShuffleRNGBackward>;
#if MGB_ENABLE_GRAD #if MGB_ENABLE_GRAD
IMPL(GaussianRNG); IMPL(GaussianRNG);
IMPL(UniformRNG); IMPL(UniformRNG);
...@@ -200,9 +202,87 @@ IMPL(PoissonRNG); ...@@ -200,9 +202,87 @@ IMPL(PoissonRNG);
IMPL(PermutationRNG); IMPL(PermutationRNG);
IMPL(BetaRNG); IMPL(BetaRNG);
#endif #endif
} } // namespace intl
} // namespace opr
} // namespace mgb
/* ================= ShuffleRNGForward ================= */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGForward);
// Builds the graph operator for the forward shuffle of `data`.
// Outputs are registered in a fixed order that the rest of this file relies
// on: output(0) has the dtype of `data`, output(1) is Int32, and the
// workspace output is added after them (make() asserts three outputs).
ShuffleRNGForward::ShuffleRNGForward(VarNode* data, const Param& param,
const OperatorNodeConfig& config)
: Super({data->owner_graph(), config, "shuffle_rng", {data}}, param) {
add_input({data});
add_output(None)->dtype(data->dtype());
add_output(None)->dtype(dtype::Int32{});
cg::add_workspace_output(this);
// NOTE(review): hashing this operator's own address presumably keeps two
// shuffle nodes from being merged as equivalent -- confirm against
// add_equivalence_component semantics
add_equivalence_component<ScalarHash<void*>>(this);
}
// Inserts a ShuffleRNGForward operator reading `in_tensor` into the graph that
// owns it and returns its first two outputs. The operator is required to have
// exactly three outputs; the third one is not handed back to the caller.
SymbolVarArray ShuffleRNGForward::make(SymbolVar in_tensor, const Param& param,
                                       const OperatorNodeConfig& config) {
    auto* data = in_tensor.node();
    auto opr = data->owner_graph()->insert_opr(
            std::make_unique<ShuffleRNGForward>(data, param, config));
    mgb_assert(opr->output().size() == 3);
    return {opr->output(0), opr->output(1)};
}
// Registers static shape inference for all three outputs:
//   output(0): identical to the shape of input(0);
//   output(1): taken from the third argument of the dnn operator's
//              deduce_layout();
//   output(2): a 1-d shape holding the workspace size in bytes reported by the
//              dnn operator.
void ShuffleRNGForward::init_output_static_infer_desc() {
    using namespace cg::static_infer;
    auto&& mgr = owner_graph()->static_infer_manager();

    mgr.register_shape_infer(output(0),
                             ShapeInferDesc::make_identity(input(0)));

    auto infer_oshp1 = [this](TensorShape& dest, const InpVal& iv) {
        // make sure m_dnn_opr exists before it is dereferenced, exactly as the
        // workspace inference below does
        ensure_megdnn_opr();
        TensorLayout o0, o1;
        m_dnn_opr->deduce_layout({iv.val[0].shape(), input(0)->dtype()}, o0,
                                 o1);
        dest = o1;
        return true;
    };
    mgr.register_shape_infer(
            output(1),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_oshp1});

    auto infer_wk = [this](TensorShape& dest, const InpVal& inp) {
        ensure_megdnn_opr();
        dest.ndim = 1;
        dest.shape[0] = m_dnn_opr->get_workspace_in_bytes(
                {inp.val[0].shape(), input(0)->dtype()},
                {output(0)->shape(), output(0)->dtype()},
                {output(1)->shape(), output(1)->dtype()});
        return true;
    };
    mgr.register_shape_infer(
            output(2),
            {SourceType::DEP, {{input(0), DepType::SHAPE}}, infer_wk});
}
// Requests a contiguous layout for the only input.
void ShuffleRNGForward::add_input_layout_constraint() {
    input(0)->add_layout_constraint_contiguous();
}
// Runs the dnn operator: input(0) is the source, output(0) and output(1) are
// the two results, and output(2) provides the workspace.
void ShuffleRNGForward::scn_do_execute() {
    auto src = input(0)->dev_tensor().as_megdnn();
    auto dst = output(0)->dev_tensor().as_megdnn();
    auto indices = output(1)->dev_tensor().as_megdnn();
    m_dnn_opr->exec(src, dst, indices,
                    get_megdnn_workspace_from_var(output(2)));
}
#if MGB_ENABLE_GRAD
// Gradient of ShuffleRNGForward with respect to its only input.
MGB_IMPL_OPR_GRAD(ShuffleRNGForward) {
// three output gradients are expected, the input index must be 0, and no
// gradient may arrive through the third output
mgb_assert(out_grad.size() == 3 && wrt_idx == 0 && !out_grad[2]);
// without a gradient on the first output there is nothing to propagate
if (!out_grad[0])
return nullptr;
// the backward operator receives the gradient of output(0), the forward
// output(1), and the forward input
return ShuffleRNGBackward::make(out_grad[0], opr.output(1), opr.input(0)).node();
}
#endif
MGB_DYN_TYPE_OBJ_FINAL_IMPL(ShuffleRNGBackward);
MEGDNN_OPR_INIT3(ShuffleRNGBackward, "shuffle_rng_bwd", 2, true)
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -6,7 +6,8 @@ ...@@ -6,7 +6,8 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#include "megbrain/opr/rand.h" #include "megbrain/opr/rand.h"
...@@ -14,6 +15,23 @@ ...@@ -14,6 +15,23 @@
namespace mgb { namespace mgb {
namespace serialization {

//! Serialization maker for ShuffleRNG with one input. Opr::make yields an
//! array of output vars, so the operator node is recovered through the owner
//! of the first one.
template <>
struct OprMaker<opr::ShuffleRNG, 1> {
    using Opr = opr::ShuffleRNG;
    using Param = Opr::Param;
    static cg::OperatorNodeBase* make(const Param& param,
                                      const cg::VarNodeArray& inputs,
                                      ComputingGraph& graph,
                                      const OperatorNodeConfig& config) {
        MGB_MARK_USED_VAR(graph);
        return Opr::make(inputs[0], param, config)[0].node()->owner_opr();
    }
};

}  // namespace serialization
namespace opr { namespace opr {
using UniformRNGV1 = opr::UniformRNG; using UniformRNGV1 = opr::UniformRNG;
...@@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2); ...@@ -24,9 +42,10 @@ MGB_SEREG_OPR(GammaRNG, 2);
MGB_SEREG_OPR(PoissonRNG, 1); MGB_SEREG_OPR(PoissonRNG, 1);
MGB_SEREG_OPR(PermutationRNG, 1); MGB_SEREG_OPR(PermutationRNG, 1);
MGB_SEREG_OPR(BetaRNG, 2); MGB_SEREG_OPR(BetaRNG, 2);
MGB_SEREG_OPR(ShuffleRNG, 1);
MGB_SEREG_OPR(ShuffleRNGBackward, 3);
} // namespace opr } // namespace opr
} // namespace mgb } // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} // vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -6,14 +6,15 @@ ...@@ -6,14 +6,15 @@
* *
* Unless required by applicable law or agreed to in writing, * Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an * software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
* implied.
*/ */
#pragma once #pragma once
#include "megbrain/graph.h" #include "megbrain/graph.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megdnn/oprs.h" #include "megdnn/oprs.h"
namespace mgb { namespace mgb {
...@@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { ...@@ -41,22 +42,24 @@ MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // {
}; };
/* ================= RNG with shape ================= */ /* ================= RNG with shape ================= */
#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ #define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \
MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase<megdnn::RNG>) \ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) \
cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \
public: \ \
RNG(VarNode *shape, const Param &param, const OperatorNodeConfig &config); \ public: \
static SymbolVar make(SymbolVar shape, const Param &param = {}, \ RNG(VarNode* shape, const Param& param, const OperatorNodeConfig& config); \
const OperatorNodeConfig &config = {}); \ static SymbolVar make(SymbolVar shape, const Param& param = {}, \
static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \ const OperatorNodeConfig& config = {}); \
const OperatorNodeConfig &config, \ static SymbolVar make(ComputingGraph& graph, const TensorShape& shape, \
const Param &param = {}) { \ const OperatorNodeConfig& config, \
return make(var_from_tensor_shape(graph, config, "rng", shape), \ const Param& param = {}) { \
param, config); \ return make(var_from_tensor_shape(graph, config, "rng", shape), param, \
} \ config); \
void init_output_static_infer_desc() override; \ } \
void scn_do_execute() override; \ void init_output_static_infer_desc() override; \
}; void scn_do_execute() override; \
} \
;
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG)
_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG)
...@@ -71,7 +74,7 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) ...@@ -71,7 +74,7 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)
public: \ public: \
RNG(_INPUTS(VarNode*), const Param &param, \ RNG(_INPUTS(VarNode*), const Param &param, \
const OperatorNodeConfig &config); \ const OperatorNodeConfig &config); \
static SymbolVar make(_INPUTS(SymbolVar),const Param &param = {}, \ static _OUTPUTS make(_INPUTS(SymbolVar),const Param &param = {}, \
const OperatorNodeConfig &config = {}); \ const OperatorNodeConfig &config = {}); \
void init_output_static_infer_desc() override; \ void init_output_static_infer_desc() override; \
void scn_do_execute() override; \ void scn_do_execute() override; \
...@@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>) ...@@ -79,17 +82,24 @@ MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase<megdnn::RNG>)
/* ================= 1 input ================= */ /* ================= 1 input ================= */
#define _INPUTS(preifx) preifx i0 #define _INPUTS(preifx) preifx i0
#define _OUTPUTS SymbolVar
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG) _DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG)
#undef _OUTPUTS
#define _OUTPUTS SymbolVarArray
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(ShuffleRNGForward)
#undef _OUTPUTS
#undef _INPUTS #undef _INPUTS
/* ================= 2 input ================= */ /* ================= 2 input ================= */
#define _INPUTS(preifx) preifx i0, preifx i1 #define _INPUTS(preifx) preifx i0, preifx i1
#define _OUTPUTS SymbolVar
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG) _DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG)
_DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) _DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG)
#undef _OUTPUTS
#undef _INPUTS #undef _INPUTS
#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS #undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS
} // intl } // intl
using UniformRNG = intl::UniformRNG; using UniformRNG = intl::UniformRNG;
using GaussianRNG = intl::GaussianRNG; using GaussianRNG = intl::GaussianRNG;
...@@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG; ...@@ -97,9 +107,20 @@ using GammaRNG = intl::GammaRNG;
using PermutationRNG = intl::PermutationRNG; using PermutationRNG = intl::PermutationRNG;
using PoissonRNG = intl::PoissonRNG; using PoissonRNG = intl::PoissonRNG;
using BetaRNG = intl::BetaRNG; using BetaRNG = intl::BetaRNG;
} // namespace opr using ShuffleRNG = intl::ShuffleRNGForward;
} // namespace mgb
MGB_DEFINE_OPR_CLASS(ShuffleRNGBackward,
intl::MegDNNOprWrapperBwd<megdnn::ShuffleRNGBackward>) //{
public:
ShuffleRNGBackward(VarNode* out_diff, VarNode* indices, VarNode* result_shape,
const Param& param, const OperatorNodeConfig& config);
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} static SymbolVar make(SymbolVar out_diff, SymbolVar indices,
SymbolVar result_shape, const Param& param = {},
const OperatorNodeConfig& config = {});
};
} // namespace opr
} // namespace mgb
// vim: ft=cpp syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
...@@ -333,6 +333,38 @@ TEST(TestOprRand, EmptyShape) { ...@@ -333,6 +333,38 @@ TEST(TestOprRand, EmptyShape) {
} }
//! Checks that the shuffled output agrees with the source gathered through the
//! returned indices: slice i of the result along the first axis must equal
//! slice indices[i] of the source.
//!
//! gtest assertions are used instead of C assert(), which compiles to nothing
//! when NDEBUG is defined and would leave the test without any check.
TEST(TestOprRand, ShuffleForward) {
    auto run = [&](TensorShape shape) {
        std::shared_ptr<HostTensorND> src_host(new HostTensorND{
                CompNode::load("xpux"), shape, dtype::Float32()});
        auto sptr = src_host->ptr<dt_float32>();
        auto size = shape.total_nr_elems();
        // every element gets a distinct value so that misplaced data is
        // detectable
        for (size_t i = 0; i < size; ++i) {
            sptr[i] = i;
        }
        auto graph = ComputingGraph::make();
        auto src_sym = opr::Host2DeviceCopy::make(*graph, src_host);
        auto rec = opr::ShuffleRNG::make(src_sym, {10});
        HostTensorND host_y, host_index;
        auto func = graph->compile({make_callback_copy(rec[0], host_y),
                                    make_callback_copy(rec[1], host_index)});
        func->execute();
        auto dptr = host_y.ptr<dt_float32>();
        auto iptr = host_index.ptr<dt_int32>();
        size_t len = shape[0];
        size_t step = size / len;
        for (size_t i = 0; i < len; ++i) {
            // reject out-of-range indices before they are used to address the
            // source buffer
            ASSERT_GE(iptr[i], 0);
            ASSERT_LT(static_cast<size_t>(iptr[i]), len);
            for (size_t j = 0; j < step; ++j) {
                ASSERT_EQ(dptr[i * step + j], sptr[iptr[i] * step + j]);
            }
        }
    };
    run({10});
    run({6, 3});
    run({1, 1});
}
TEST(TestOprRand, UniformReprod) { TEST(TestOprRand, UniformReprod) {
static constexpr size_t SIZE = 123; static constexpr size_t SIZE = 123;
auto graph = ComputingGraph::make(); auto graph = ComputingGraph::make();
......
...@@ -114,6 +114,7 @@ union OperatorParam { ...@@ -114,6 +114,7 @@ union OperatorParam {
param.BetaRNG = 80, param.BetaRNG = 80,
param.SlidingWindowTranspose = 81, param.SlidingWindowTranspose = 81,
param.Padding = 82, param.Padding = 82,
param.ShuffleRNG = 83,
} }
table Operator { table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册