From b078dda90bd0c95a5d4e9cabe05411ad0874b930 Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Mon, 7 Jun 2021 16:44:20 +0800 Subject: [PATCH] feat(mge/random): add some random op and remove random/distrbution.py GitOrigin-RevId: 4c05ebc2662408c098b0d5ddd42cc39b0d9d97f3 --- dnn/include/megdnn/oprs/utils.h | 60 +- dnn/scripts/opr_param_defs.py | 27 +- dnn/src/common/handle_impl.h | 4 + dnn/src/common/opr_trait.h | 4 + dnn/src/common/rng.cpp | 55 +- dnn/src/cuda/argsort/argsort.cu | 71 ++- dnn/src/cuda/argsort/argsort.cuh | 6 + dnn/src/cuda/rng/kernel.cu | 174 ++++++ dnn/src/cuda/rng/kernel.cuh | 258 +++++++++ dnn/src/cuda/rng/opr_impl.cpp | 139 +++++ dnn/src/cuda/rng/opr_impl.h | 141 ++++- dnn/src/naive/rng/opr_impl.cpp | 244 ++++++++ dnn/src/naive/rng/opr_impl.h | 117 ++-- dnn/test/cuda/rng.cpp | 186 +++++- dnn/test/naive/rng.cpp | 163 +++++- .../python/megengine/random/__init__.py | 15 +- .../python/megengine/random/distribution.py | 95 --- imperative/python/megengine/random/rng.py | 543 +++++++++++++++++- imperative/python/src/ops.cpp | 1 + .../python/test/unit/random/test_rng.py | 278 ++++++++- imperative/src/impl/ops/rng.cpp | 264 +++++++-- .../src/include/megbrain/imperative/ops/rng.h | 1 + imperative/src/test/rng.cpp | 62 +- src/core/include/megbrain/ir/ops.td | 71 ++- src/opr/impl/rand.cpp | 216 ++++--- src/opr/impl/rand.sereg.h | 4 + src/opr/include/megbrain/opr/rand.h | 102 ++-- src/opr/test/rand.cpp | 328 ++++++++--- src/serialization/impl/schema.fbs | 4 + 29 files changed, 3136 insertions(+), 497 deletions(-) create mode 100644 dnn/src/cuda/rng/kernel.cu create mode 100644 dnn/src/cuda/rng/kernel.cuh delete mode 100644 imperative/python/megengine/random/distribution.py diff --git a/dnn/include/megdnn/oprs/utils.h b/dnn/include/megdnn/oprs/utils.h index e2303dc7..ddf3cedf 100644 --- a/dnn/include/megdnn/oprs/utils.h +++ b/dnn/include/megdnn/oprs/utils.h @@ -21,19 +21,77 @@ class RNGBase: public OperatorBase { _megdnn_workspace workspace) = 0; virtual size_t get_workspace_in_bytes(const TensorLayout &dst) = 0; protected: - void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); + virtual void check_exec(const TensorLayout &dst, size_t workspace_in_bytes) = 0; +}; + +//! sample from poisson distribution +class PoissonRNG: public OperatorBase { + DEF_OPR_IMPL(PoissonRNG, OperatorBase, 1, 1); + DEF_OPR_PARAM(PoissonRNG); + public: + virtual void exec(_megdnn_tensor_in lam, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout &lam, + const TensorLayout &dst) = 0; + protected: + void check_exec(const TensorLayout &lam, const TensorLayout &dst, + size_t workspace_in_bytes); +}; + +//! sample from beta distribution +class BetaRNG: public OperatorBase { + DEF_OPR_IMPL(BetaRNG, OperatorBase, 2, 1); + DEF_OPR_PARAM(BetaRNG); + public: + virtual void exec(_megdnn_tensor_in alpha, + _megdnn_tensor_in beta, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout &alpha, + const TensorLayout &beta, const TensorLayout &dst) = 0; + protected: + void check_exec(const TensorLayout &alpha, const TensorLayout &beta, + const TensorLayout &dst, size_t workspace_in_bytes); +}; + +//! 
sample from gamma distribution +class GammaRNG: public OperatorBase { + DEF_OPR_IMPL(GammaRNG, OperatorBase, 2, 1); + DEF_OPR_PARAM(GammaRNG); + public: + virtual void exec(_megdnn_tensor_in shape, + _megdnn_tensor_in scale, + _megdnn_tensor_out dst, + _megdnn_workspace workspace) = 0; + virtual size_t get_workspace_in_bytes(const TensorLayout &shape, + const TensorLayout &scale, const TensorLayout &dst) = 0; + protected: + void check_exec(const TensorLayout &shape, const TensorLayout &scale, + const TensorLayout &dst, size_t workspace_in_bytes); }; //! sample from uniform distribution on the interval (0, 1] class UniformRNG: public RNGBase { DEF_OPR_IMPL(UniformRNG, RNGBase, 0, 1); DEF_OPR_PARAM(UniformRNG); + protected: + void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); }; //! sample from gaussian distribution class GaussianRNG: public RNGBase { DEF_OPR_IMPL(GaussianRNG, RNGBase, 0, 1); DEF_OPR_PARAM(GaussianRNG); + protected: + void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); +}; + +class PermutationRNG: public RNGBase { + DEF_OPR_IMPL(PermutationRNG, RNGBase, 0, 1); + DEF_OPR_PARAM(PermutationRNG); + protected: + void check_exec(const TensorLayout &dst, size_t workspace_in_bytes); }; /*! diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 566dc2f5..3a81e0ca 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -735,11 +735,34 @@ pdef('Sleep').add_fields('float32', Doc('time', 'time to sleep in seconds'), 0) 'dtype', Doc('dtype', 'data type of output value'), 'DTypeEnum::Float32')) -pdef('UniformRNG').add_fields('uint64', 'seed', 0) +(pdef('UniformRNG'). + add_fields('uint64', 'seed', 0). + add_fields( + 'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'), + 'DTypeEnum::Float32')) (pdef('GaussianRNG'). add_fields('uint64', 'seed', 0). - add_fields('float32', 'mean', 0, 'std', 1)) + add_fields('float32', 'mean', 0, 'std', 1). + add_fields( + 'dtype', Doc('dtype', 'The dtype of output Tensor. Only support Float32.'), + 'DTypeEnum::Float32')) + +(pdef('GammaRNG'). + add_fields('uint64', 'seed', 0)) + +(pdef('BetaRNG'). + add_fields('uint64', 'seed', 0)) + +(pdef('PoissonRNG'). + add_fields('uint64', 'seed', 0)) + +(pdef('PermutationRNG'). + add_fields('uint64', 'seed', 0). + add_fields( + 'dtype', Doc('dtype', 'The dtype of output Tensor. Int32, Int16 and ' + 'Float32 are supported.'), + 'DTypeEnum::Int32')) (pdef('Flip'). 
add_fields('bool', 'vertical', 'false', 'horizontal', 'false')) diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 03c83053..e2cc6121 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -159,6 +159,10 @@ private: cb(SleepForward) \ cb(UniformRNG) \ cb(GaussianRNG) \ + cb(GammaRNG) \ + cb(BetaRNG) \ + cb(PoissonRNG) \ + cb(PermutationRNG) \ cb(SeparableConvForward) \ cb(SeparableFilterForward) \ cb(BNForward) \ diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 646ba284..cb181582 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -120,6 +120,10 @@ DEF(TQTBackward, 5, true, false); DEF(PowC, 2, false, true); DEF(UniformRNG, 1, true, true); DEF(GaussianRNG, 1, true, true); +DEF(GammaRNG, 3, true, true); +DEF(BetaRNG, 3, true, true); +DEF(PoissonRNG, 2, true, true); +DEF(PermutationRNG, 1, true, true); DEF(ChecksumForward, 1, true, false); DEF(CheckHasInf, 2, true, true); DEF(LSQForward, 5, true, true); diff --git a/dnn/src/common/rng.cpp b/dnn/src/common/rng.cpp index eba6a9d2..fefb4add 100644 --- a/dnn/src/common/rng.cpp +++ b/dnn/src/common/rng.cpp @@ -15,13 +15,62 @@ namespace megdnn { -void RNGBase::check_exec( +void PermutationRNG::check_exec( const TensorLayout &dst, size_t workspace_in_bytes) { - megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && - dst.is_contiguous()); + megdnn_assert((dst.dtype == dtype::Float32() || + dst.dtype == dtype::Int32() || + dst.dtype == dtype::Int16() ) && + dst.dtype.enumv() == param().dtype && + dst.is_contiguous()); megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); } +void PoissonRNG::check_exec(const TensorLayout &lam, const TensorLayout &dst, + size_t workspace_in_bytes){ + megdnn_assert( dst.dtype.category() == DTypeCategory::FLOAT && + lam.dtype == dst.dtype); + megdnn_assert(dst.is_contiguous() && lam.is_contiguous()); + megdnn_assert(lam.total_nr_elems() == dst.total_nr_elems()); + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(lam, dst)); +} + +void GammaRNG::check_exec(const TensorLayout &shape,const TensorLayout &scale, + const TensorLayout &dst, size_t workspace_in_bytes){ + megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && + shape.dtype == dst.dtype && + scale.dtype == dst.dtype); + megdnn_assert(shape.is_contiguous() && scale.is_contiguous() + && dst.is_contiguous()); + megdnn_assert(shape.total_nr_elems() == dst.total_nr_elems() && + scale.total_nr_elems() == dst.total_nr_elems()); + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(shape,scale,dst)); +} + +void BetaRNG::check_exec(const TensorLayout &alpha,const TensorLayout &beta, + const TensorLayout &dst, size_t workspace_in_bytes){ + megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && + alpha.dtype == dst.dtype && + beta.dtype == dst.dtype); + megdnn_assert(alpha.is_contiguous() && beta.is_contiguous() + && dst.is_contiguous()); + megdnn_assert(alpha.total_nr_elems() == dst.total_nr_elems() && + beta.total_nr_elems() == dst.total_nr_elems()); + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(alpha,beta, dst)); +} + +#define INST_CHECK_EXEC(RNG_NAME) \ + void RNG_NAME::check_exec( \ + const TensorLayout &dst, size_t workspace_in_bytes) { \ + megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT && \ + dst.dtype.enumv() == param().dtype && \ + dst.is_contiguous()); \ + megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst)); \ + } + +INST_CHECK_EXEC(UniformRNG) +INST_CHECK_EXEC(GaussianRNG) 
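+// For reference, INST_CHECK_EXEC(UniformRNG) expands to roughly the
+// following (a sketch of the macro above, not extra code):
+//
+//   void UniformRNG::check_exec(const TensorLayout &dst,
+//                               size_t workspace_in_bytes) {
+//       megdnn_assert(dst.dtype.category() == DTypeCategory::FLOAT &&
+//                     dst.dtype.enumv() == param().dtype &&
+//                     dst.is_contiguous());
+//       megdnn_assert(workspace_in_bytes >= get_workspace_in_bytes(dst));
+//   }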
+#undef INST_CHECK_EXEC + } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/argsort/argsort.cu b/dnn/src/cuda/argsort/argsort.cu index a8781cf5..666cf62f 100644 --- a/dnn/src/cuda/argsort/argsort.cu +++ b/dnn/src/cuda/argsort/argsort.cu @@ -49,23 +49,42 @@ bool use_segmented(uint32_t M, uint32_t /*N*/) { return M >= 8; } -template -MEGDNN_NOINLINE size_t cub_sort_pairs( +__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { + uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; + if (i < n) { + dst[i] = i % mod; + } +} + +template +size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { + if (use_bitonic(M, N)) { + return 0; + } + return argsort::cub_sort_pairs(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, + M, N, 0, sizeof(float)*8, NULL); +} +} // anonymous namespace + +template +MEGDNN_NOINLINE size_t argsort::cub_sort_pairs( bool is_ascending, void* workspace, size_t workspace_size, - const KeyType* keys_in, KeyType* keys_out, const int* values_in, - int* values_out, uint32_t M, uint32_t N, cudaStream_t stream) { + const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, + ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,cudaStream_t stream){ cudaError_t err; if (use_segmented(M, N)) { if (is_ascending) { err = cub::DeviceSegmentedRadixSort::SortPairs( workspace, workspace_size, keys_in, keys_out, values_in, values_out, N * M, M, StridedOffsetIterator(0, N), - StridedOffsetIterator(N, N), 0, sizeof(float) * 8, stream); + StridedOffsetIterator(N, N), begin_bit, end_bit, stream); + cuda_check(err); } else { err = cub::DeviceSegmentedRadixSort::SortPairsDescending( workspace, workspace_size, keys_in, keys_out, values_in, values_out, N * M, M, StridedOffsetIterator(0, N), - StridedOffsetIterator(N, N), 0, sizeof(float) * 8, stream); + StridedOffsetIterator(N, N), begin_bit, end_bit, stream); + cuda_check(err); } } else { if (is_ascending) { @@ -73,7 +92,7 @@ MEGDNN_NOINLINE size_t cub_sort_pairs( err = cub::DeviceRadixSort::SortPairs( workspace, workspace_size, keys_in + N * i, keys_out + N * i, values_in + N * i, values_out + N * i, - N, 0, sizeof(float) * 8, stream); + N, begin_bit, end_bit, stream); cuda_check(err); if (!keys_in) { return workspace_size; @@ -84,7 +103,7 @@ MEGDNN_NOINLINE size_t cub_sort_pairs( err = cub::DeviceRadixSort::SortPairsDescending( workspace, workspace_size, keys_in + N * i, keys_out + N * i, values_in + N * i, values_out + N * i, - N, 0, sizeof(float) * 8, stream); + N, begin_bit, end_bit, stream); cuda_check(err); if (!keys_in) { return workspace_size; @@ -95,23 +114,6 @@ MEGDNN_NOINLINE size_t cub_sort_pairs( return workspace_size; } -__global__ void kern_arange(int* dst, uint32_t n, uint32_t mod) { - uint32_t i = threadIdx.x + blockIdx.x * blockDim.x; - if (i < n) { - dst[i] = i % mod; - } -} - -template -size_t get_sort_workspace(uint32_t M, uint32_t N, bool is_ascending) { - if (use_bitonic(M, N)) { - return 0; - } - return cub_sort_pairs(is_ascending, NULL, 0, NULL, NULL, NULL, NULL, - M, N, NULL); -} -} // anonymous namespace - size_t argsort::get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, bool is_ascending, bool iptr_src_given) { @@ -151,17 +153,28 @@ void argsort::forward(const dtype* sptr, dtype* dptr, int* iptr, stream)); } else { cub_sort_pairs(is_ascending, workspace, wk_size, sptr, dptr, iptr_src, - iptr, M, N, stream); + iptr, M, N, 0, sizeof(float)*8, stream); } } namespace megdnn { namespace cuda { + +#define INST_CUB_SORT(dtype) \ 
+template MEGDNN_NOINLINE size_t argsort::cub_sort_pairs(bool, \ + void*, size_t, const dtype*, dtype*, \ + const dtype*, dtype*, uint32_t, uint32_t,\ + int, int, cudaStream_t); + #define INST_FORWARD(dtype) \ - template void argsort::forward(const dtype*, dtype*, int*, void*, \ - uint32_t, uint32_t, bool, \ - cudaStream_t, const int*); +template void argsort::forward(const dtype*, dtype*, int*, void*, \ + uint32_t, uint32_t, bool, cudaStream_t, \ + const int*); + ARGSORT_FOREACH_CTYPE(INST_FORWARD) +INST_CUB_SORT(uint32_t) +INST_CUB_SORT(uint64_t) +#undef INST_CUB_SORT #undef INST_FORWARD } } // namespace megdnn diff --git a/dnn/src/cuda/argsort/argsort.cuh b/dnn/src/cuda/argsort/argsort.cuh index d6301d65..77c02078 100644 --- a/dnn/src/cuda/argsort/argsort.cuh +++ b/dnn/src/cuda/argsort/argsort.cuh @@ -24,6 +24,12 @@ size_t get_fwd_workspace_in_bytes(uint32_t M, uint32_t N, DType dtype, bool is_ascending, bool iptr_src_given = false); +template +size_t cub_sort_pairs( + bool is_ascending, void* workspace, size_t workspace_size, + const KeyType* keys_in, KeyType* keys_out, const ValueType* values_in, + ValueType* values_out, uint32_t M, uint32_t N, int begin_bit, int end_bit,cudaStream_t stream); + /*! * \param iptr_src pointer to indices; a range would be generated if it is null */ diff --git a/dnn/src/cuda/rng/kernel.cu b/dnn/src/cuda/rng/kernel.cu new file mode 100644 index 00000000..0db1e914 --- /dev/null +++ b/dnn/src/cuda/rng/kernel.cu @@ -0,0 +1,174 @@ +/** + * \file dnn/src/cuda/rnd/kernel.cu + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
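+ *
+ * A sketch of the sort-based shuffle implemented below: RandomKernel draws
+ * one random key per element, cub_sort_pairs argsorts the element indices
+ * by those keys, and permute_duplicate_keys_kernel then Fisher-Yates
+ * shuffles every run of equal (masked) keys so that key collisions do not
+ * bias the resulting permutation. get_permutation_bits chooses the key
+ * width so that such collisions stay rare in the first place.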
+ */ +#include +#include +#include "../argsort/argsort.cuh" +#include "./kernel.cuh" +#include "src/cuda/cuda_shfl_compat.cuh" +#include "src/cuda/utils.cuh" + +namespace megdnn { + +namespace cuda { + +namespace random { + +template +__global__ void permute_duplicate_keys_kernel(KeyType* keys, ValueType* indexs, + KeyType mask, size_t size, + uint64_t seed, uint64_t offset) { + uint32_t idx = threadIdx.x + blockDim.x * blockIdx.x; + if (idx >= size - 1) return; + uint32_t lane_idx = threadIdx.x & 0x1F; + KeyType cur_key = keys[idx] & mask; + + KeyType r_key = __shfl_down(cur_key, 1, 32); + if (lane_idx == 31) r_key = keys[idx + 1] & mask; + if (cur_key != r_key) return; + + KeyType l_key = __shfl_up(cur_key, 1, 32); + if (idx != 0 && lane_idx == 0) l_key = keys[idx - 1] & mask; + if (cur_key == l_key) return; + + indexs += idx; + int32_t duplicate_size = 1; + + for (; idx + duplicate_size < size && cur_key == (keys[idx + duplicate_size] & mask); + ++duplicate_size){}; + Philox state; + curand_init(seed, idx, offset, &state); + for (int32_t i = duplicate_size - 1; i > 0; --i) { + int32_t r = static_cast(curand(&state) & 0x7fffffff) % (i + 1); + if (i != r) { + ValueType tmp = indexs[i]; + indexs[i] = indexs[r]; + indexs[r] = tmp; + } + } +} + +uint32_t get_permutation_bits(size_t N) { + double uniq_rand_num_prob = 0.9; + double thresh = std::log(uniq_rand_num_prob) * 12; + double dN = static_cast(N); + uint32_t bits = std::min(64, static_cast(std::ceil(std::log2( + dN - (6 * dN * dN + 1) / thresh)))); + return bits; +} + +size_t get_permutation_workspace_in_bytes(size_t size) { + uint32_t bits = get_permutation_bits(size); + size_t work_size = 0; +#define cb(KeyType, ValueType) \ + size_t random_src_size = size * sizeof(KeyType); \ + size_t indexs_size = size * sizeof(ValueType); \ + size_t sort_worksize = argsort::cub_sort_pairs( \ + false, NULL, 0, NULL, NULL, NULL, NULL, 1, size, 0, bits, NULL); \ + work_size = 2 * random_src_size + 2 * indexs_size + \ + DIVUP(sort_worksize, sizeof(KeyType)) * sizeof(KeyType); + if (bits > 32) { + cb(uint64_t, uint64_t) + } else { + cb(uint32_t, uint32_t) + } +#undef cb + return work_size; +} + +template +void permutation_cuda(ctype* dst, void* workspace, size_t size, uint64_t seed, + uint64_t offset, uint32_t bits, cudaStream_t stream) { + int threads = 512; + int blocks = DIVUP(size, threads); + using KeyType = typename std::conditional::type; + using ValueType = KeyType; + + // split workspace + KeyType* keys_in = static_cast(workspace); + KeyType* keys_out = keys_in + size; + ValueType* values_in = static_cast(keys_out + size); + ValueType* values_out = values_in + size; + void* extra_workspace = static_cast(values_out + size); + + // init indexs + ElemwiseOpParamN<0> ele_param(size); + typedef RangeKernel rangeOp; + rangeOp range_op; + range_op.output = values_in; + run_elemwise(ele_param, stream, range_op); + + // generate random smaple + typedef RandomKernel randomOP; + randomOP random_op; + random_op.output = keys_in; + random_op.seed = seed; + random_op.offset = offset; + run_elemwise(ele_param, stream, random_op); + + // argsort random sample + size_t wk_size = argsort::cub_sort_pairs( + false, NULL, 0, NULL, NULL, NULL, NULL, 1, size, 0, bits, NULL); + argsort::cub_sort_pairs( + false, extra_workspace, wk_size, keys_in, keys_out, values_in, + values_out, 1, size, 0, bits, stream); + + // permute duplicate sample + KeyType mask = static_cast((1ULL << bits) - 1); + permute_duplicate_keys_kernel + <<>>(keys_out, values_out, mask, size, + seed, 
offset); + after_kernel_launch(); + + typedef AsTypeKernel asTypeOP; + asTypeOP as_type_op; + as_type_op.input = values_out; + as_type_op.output = dst; + run_elemwise(ele_param, stream, as_type_op); +} + +template +void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed, + uint64_t offset, cudaStream_t stream) { + uint32_t bits = get_permutation_bits(size); + if (bits <= 32) { + permutation_cuda(dst, workspace, size, seed, offset, bits, + stream); + } else { + permutation_cuda(dst, workspace, size, seed, offset, bits, + stream); + } +} + +#define INST_PERMUTATION(T) \ + template void permutation_forward(T*, void*, size_t, uint64_t, uint64_t, \ + cudaStream_t); \ + +INST_PERMUTATION(dt_int32) +INST_PERMUTATION(dt_int16) +INST_PERMUTATION(dt_float32) +#undef INST_PERMUTATION + +} // namespace random + +#define INST(_dtype) \ + INST_RUN_ELEMWISE(random::GammaKernel::ctype>, \ + DTypeTrait<_dtype>::ctype, 0); \ + INST_RUN_ELEMWISE(random::PoissonKernel::ctype>, \ + DTypeTrait<_dtype>::ctype, 0); \ + INST_RUN_ELEMWISE(random::BetaKernel::ctype>, \ + DTypeTrait<_dtype>::ctype, 0); \ + +INST(megdnn::dtype::Float32) +INST(megdnn::dtype::Float16) +INST(megdnn::dtype::BFloat16) +#undef INST +} // namespace cuda +} // namespace megdnn \ No newline at end of file diff --git a/dnn/src/cuda/rng/kernel.cuh b/dnn/src/cuda/rng/kernel.cuh new file mode 100644 index 00000000..1ba6cbe8 --- /dev/null +++ b/dnn/src/cuda/rng/kernel.cuh @@ -0,0 +1,258 @@ +/** + * \file dnn/src/cuda/rng/kernel.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
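+ *
+ * A note on the generator used throughout: Philox is counter-based, so
+ * curand_init(seed, subsequence, offset, &state) gives every element index
+ * its own independent stream with no per-thread state array. The host side
+ * (dnn/src/cuda/rng/opr_impl.cpp) only advances the offset by a per-op
+ * constant after each launch so that successive calls do not reuse
+ * counters; the exact increments are heuristic bounds on per-thread draws.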
+ */ +#pragma once +#include +#include + +#include +#include + +#include "megdnn/dtype.h" +#include "src/cuda/elemwise_helper.cuh" +#include "src/cuda/utils.cuh" + +#if MEGDNN_CC_HOST +#include "megdnn/oprs.h" +#endif + +namespace megdnn { +namespace cuda { +namespace random { + +using Philox = curandStatePhilox4_32_10_t; + +QUALIFIERS float _curand_uniform(Philox *state){ + float r = curand_uniform(state); + if (r >= 1.0f) { + r = 0.0f; + } + return r; +} + +template +struct RandomKernel; + +template +using enable_64bit = typename std::enable_if::value && ((sizeof(ctype)) == 8)>::type; + +template +using enable_32bit = typename std::enable_if::value && ((sizeof(ctype)) <= 4)>::type; + +template +struct RandomKernel>{ + ctype* output; + uint64_t seed, offset; + uint64_t mask = static_cast(std::numeric_limits::max()); + __device__ void operator()(uint32_t idx){ + Philox local_state; + curand_init(seed, idx, offset, &local_state); + uint4 rand = curand4(&local_state); + uint64_t val = (static_cast(rand.x) << 32) | rand.y; + output[idx] = static_cast(val & mask); + } +#if MEGDNN_CC_HOST + RandomKernel(const ctype* output, uint64_t seed, uint64_t offset) + : output{output}, + seed{seed}, + offset{offset}{} +#endif +}; + +template +struct RandomKernel>{ + ctype* output; + uint64_t seed, offset; + uint32_t mask = static_cast(std::numeric_limits::max()); + __device__ void operator()(uint32_t idx){ + Philox local_state; + curand_init(seed, idx, offset, &local_state); + uint32_t val = curand(&local_state); + output[idx] = static_cast(val & mask); + } +#if MEGDNN_CC_HOST + RandomKernel(const ctype* output, uint64_t seed, uint64_t offset) + : output{output}, + seed{seed}, + offset{offset}{} +#endif +}; + +template +struct RangeKernel{ + ctype* output; + __device__ void operator()(uint32_t idx){ + output[idx] = static_cast(idx); + } +#if MEGDNN_CC_HOST + RangeKernel(const ctype* output) + : output{output}{} +#endif +}; + +template +struct AsTypeKernel{ + ctype_src* input; + ctype_dst* output; + using ctype_mask =typename std::conditional::value, ctype_dst, ctype_src>::type; + ctype_src mask = static_cast(std::numeric_limits::max()); + __device__ void operator()(uint32_t idx){ + output[idx] = static_cast(input[idx] & mask); + } +#if MEGDNN_CC_HOST + AsTypeKernel(const ctype_src* input, const ctype_dst* output) + : input{input}, output{output}{} +#endif +}; + +template +struct GammaKernel { + ctype* output; + ctype* shape; + ctype* scale; + uint64_t seed, offset; + + static __device__ float sample_gamma(float a, float b, Philox* state){ + float scale = b; + if (a <= 0) + return 0.f; + if (a < 1.0f) { + scale *= powf(_curand_uniform(state), 1.0f / a); + a += 1.0f; + } + float d = a - 1.0f / 3.0f; + float c = 1.0f / sqrtf(9.0f * d); + while (1) { + float x, y; + x = curand_normal(state); + y = 1.0f + c * x; + if (y <= 0) + continue; + + float v = y * y * y; + float u = _curand_uniform(state); + float xx = x * x; + + if ((u < 1.0f - 0.0331f * xx * xx) || + logf(u) < 0.5f * xx + d * (1.0f - v + logf(v))) + return scale * d * v; + } + } + + __device__ void operator()(uint32_t idx) { + Philox local_state; + curand_init(seed, idx, offset, &local_state); + float a = static_cast(shape[idx]); + float b = static_cast(scale[idx]); + output[idx] = static_cast(sample_gamma(a, b, &local_state)); + } + +#if MEGDNN_CC_HOST + GammaKernel(const TensorND& output, const TensorND& shape, + const TensorND& scale, uint64_t seed, uint64_t offset) + : output{output.ptr()}, + shape{shape.ptr()}, + scale{scale.ptr()}, + 
seed{seed}, + offset{offset}{} +#endif +}; + +template +struct PoissonKernel{ + ctype* output; + ctype* lambda; + uint64_t seed, offset; + + __device__ void operator()(uint32_t idx){ + Philox local_state; + curand_init(seed, idx, offset, &local_state); + float lam = static_cast(lambda[idx]); + output[idx] = static_cast(curand_poisson(&local_state, lam)); + } + +#if MEGDNN_CC_HOST + PoissonKernel(const TensorND& output,const TensorND& lambda, + uint64_t seed, uint64_t offset) + : output{output.ptr()}, + lambda{lambda.ptr()}, + seed{seed}, + offset{offset}{} +#endif +}; + +template +struct BetaKernel{ + ctype* output; + ctype* alpha; + ctype* beta; + uint64_t seed, offset; + + __device__ void operator()(uint32_t idx){ + Philox local_state; + curand_init(seed, idx, offset, &local_state); + float a = static_cast(alpha[idx]); + float b = static_cast(beta[idx]); + if(a <= 0 || b <= 0){ + output[idx] = 0; + return; + } + if( a < 1.0f && b < 1.0f){ + float u, v, x, y; + while (true) + { + u = _curand_uniform(&local_state); + v = _curand_uniform(&local_state); + x = powf(u, 1.0f / a); + y = powf(v, 1.0f / b); + if (x + y < 1.0f) { + if (x + y > 0) { + output[idx] = static_cast(x / (x + y)); + return ; + } else { + float logx = logf(u) / a; + float logy = logf(v) / b; + float log_max = logx > logy ? logx : logy; + logx -= log_max; + logy -= log_max; + output[idx] = static_cast(exp(logx - + log(exp(logx) + exp(logy)))); + return ; + } + } + } + }else{ + float ga = GammaKernel::sample_gamma(a, 1.0f, &local_state); + float gb = GammaKernel::sample_gamma(b, 1.0f, &local_state); + output[idx] = static_cast(ga / ( ga + gb)); + return ; + } + } + +#if MEGDNN_CC_HOST + BetaKernel(const TensorND& output, const TensorND& alpha, + const TensorND& beta, uint64_t seed, uint64_t offset) + : output{output.ptr()}, + alpha{alpha.ptr()}, + beta{beta.ptr()}, + seed{seed}, + offset{offset}{} +#endif +}; + +template +void permutation_forward(ctype* dst, void* workspace, size_t size, uint64_t seed, + uint64_t offset, cudaStream_t stream); + +size_t get_permutation_workspace_in_bytes(size_t N); + +} // namespace random +} // namespace cuda +} // namespace megdnn diff --git a/dnn/src/cuda/rng/opr_impl.cpp b/dnn/src/cuda/rng/opr_impl.cpp index 4571733b..d31b8d82 100644 --- a/dnn/src/cuda/rng/opr_impl.cpp +++ b/dnn/src/cuda/rng/opr_impl.cpp @@ -13,6 +13,7 @@ #include "src/cuda/handle.h" #include "src/cuda/utils.h" #include "./opr_impl.h" +#include "./kernel.cuh" using namespace megdnn; using namespace cuda; @@ -122,5 +123,143 @@ size_t GaussianRNGImpl::get_workspace_in_bytes(const TensorLayout &layout) { return 0; } +GammaRNGImpl::GammaRNGImpl(Handle *handle): + GammaRNG(handle), + m_seed(0), + m_offset(0), + m_stream(cuda_stream(handle)) +{ +} + +void GammaRNGImpl::exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale, + _megdnn_tensor_inout dst, _megdnn_workspace workspace) { + check_exec(shape.layout, scale.layout ,dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + megdnn_assert(size); + ensure_seed(m_param.seed); + ElemwiseOpParamN<0> ele_param(size); + switch (dst.layout.dtype.enumv()){ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + run_elemwise, ctype, 0>(ele_param, m_stream, \ + {dst, shape, scale, m_seed, m_offset}); \ + break ; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + default: + megdnn_throw("bad dtype"); +} + m_offset += 16; +} + +PoissonRNGImpl::PoissonRNGImpl(Handle *handle): + PoissonRNG(handle), + m_seed(0), + 
m_offset(0), + m_stream(cuda_stream(handle)) +{ +} + +void PoissonRNGImpl::exec(_megdnn_tensor_in lam, _megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(lam.layout, dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + megdnn_assert(size); + ensure_seed(m_param.seed); + ElemwiseOpParamN<0> ele_param(size); + switch (dst.layout.dtype.enumv()){ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + run_elemwise, ctype, 0>(ele_param, m_stream, \ + {dst, lam, m_seed, m_offset}); \ + break; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + default: + megdnn_throw("bad dtype"); +} + m_offset += 20; +} + +BetaRNGImpl::BetaRNGImpl(Handle *handle): + BetaRNG(handle), + m_seed(0), + m_offset(0), + m_stream(cuda_stream(handle)) +{ +} + +void BetaRNGImpl::exec(_megdnn_tensor_in alpha, _megdnn_tensor_in beta,_megdnn_tensor_out dst, + _megdnn_workspace workspace) { + check_exec(alpha.layout, beta.layout ,dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + megdnn_assert(size); + ensure_seed(m_param.seed); + ElemwiseOpParamN<0> ele_param(size); + switch (dst.layout.dtype.enumv()){ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + run_elemwise, ctype, 0>(ele_param, m_stream, \ + {dst, alpha, beta, m_seed, m_offset}); \ + break; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + default: + megdnn_throw("bad dtype"); +} + m_offset += 32; +} + +PermutationRNGImpl::PermutationRNGImpl(Handle *handle): + PermutationRNG(handle), + m_seed(0), + m_offset(0), + m_stream(cuda_stream(handle)) +{ +} + +void PermutationRNGImpl::exec( + _megdnn_tensor_inout dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + megdnn_assert(size); + ensure_seed(m_param.seed); + + auto wk = workspace.ptr(); + switch (dst.layout.dtype.enumv()){ +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + ctype max_size = DTypeTrait<_dt>::max() - 1; \ + megdnn_assert(ctype(size) < max_size); \ + random::permutation_forward(dst.ptr(), wk, size, m_seed, \ + m_offset, m_stream); \ + break; \ + } + cb(::megdnn::dtype::Float32) + cb(::megdnn::dtype::Int32) + cb(::megdnn::dtype::Int16) +#undef cb + default: + megdnn_throw("bad dtype"); +} + m_offset += 8; +} + +size_t PermutationRNGImpl::get_workspace_in_bytes(const TensorLayout &layout){ + size_t size = layout.total_nr_elems(); + return random::get_permutation_workspace_in_bytes(size); +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/rng/opr_impl.h b/dnn/src/cuda/rng/opr_impl.h index 47033bb1..451c224f 100644 --- a/dnn/src/cuda/rng/opr_impl.h +++ b/dnn/src/cuda/rng/opr_impl.h @@ -10,9 +10,9 @@ */ #pragma once +#include #include "megdnn/oprs.h" #include "src/cuda/handle.h" -#include namespace megdnn { namespace cuda { @@ -22,51 +22,136 @@ class CuRandHandle { uint64_t m_seed; CuRandHandle(const CuRandHandle&) = delete; - CuRandHandle& operator = (const CuRandHandle&) = delete; + CuRandHandle& operator=(const CuRandHandle&) = delete; - public: - CuRandHandle(cudaStream_t stream, uint64_t seed = 0); - ~CuRandHandle(); +public: + CuRandHandle(cudaStream_t stream, uint64_t seed = 0); + ~CuRandHandle(); - void seed(uint64_t seed); + void seed(uint64_t seed); - curandGenerator_t gen() const { - return m_gen; - } + curandGenerator_t gen() const { return m_gen; } - void ensure_seed(uint64_t seed) { - if 
(m_seed != seed) { - this->seed(seed); - } + void ensure_seed(uint64_t seed) { + if (m_seed != seed) { + this->seed(seed); } + } }; -class UniformRNGImpl: public UniformRNG { +class UniformRNGImpl : public UniformRNG { CuRandHandle m_curand_handle; - public: - UniformRNGImpl(Handle *handle); - void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; +public: + UniformRNGImpl(Handle* handle); + void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&) override { - return 0; - } + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; -class GaussianRNGImpl: public GaussianRNG { +class GaussianRNGImpl : public GaussianRNG { CuRandHandle m_curand_handle; - public: - GaussianRNGImpl(Handle *handle); +public: + GaussianRNGImpl(Handle* handle); + + void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout& layout) override; +}; + +class GammaRNGImpl : public GammaRNG { + uint64_t m_seed, m_offset; + cudaStream_t m_stream; + +public: + GammaRNGImpl(Handle* handle); + void exec(_megdnn_tensor_in shape,_megdnn_tensor_in scale, + _megdnn_tensor_out dst, _megdnn_workspace) override; - void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } - size_t get_workspace_in_bytes(const TensorLayout &layout) override; + void seed(uint64_t seed) { m_seed = seed; } + + void ensure_seed(uint64_t seed) { + if (m_seed != seed) { + this->seed(seed); + } + } }; +class BetaRNGImpl : public BetaRNG { + uint64_t m_seed, m_offset; + cudaStream_t m_stream; -} // namespace cuda -} // namespace megdnn -// vim: syntax=cpp.doxygen +public: + BetaRNGImpl(Handle* handle); + + void exec(_megdnn_tensor_in alpha,_megdnn_tensor_in beta, + _megdnn_tensor_out dst, _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, const TensorLayout&, + const TensorLayout&) override { + return 0; + } + + void seed(uint64_t seed) { m_seed = seed; } + + void ensure_seed(uint64_t seed) { + if (m_seed != seed) { + this->seed(seed); + } + } +}; +class PoissonRNGImpl : public PoissonRNG { + uint64_t m_seed, m_offset; + cudaStream_t m_stream; + +public: + PoissonRNGImpl(Handle* handle); + + void exec(_megdnn_tensor_in lam, _megdnn_tensor_out dst, + _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override { + return 0; + } + + void seed(uint64_t seed) { m_seed = seed; } + + void ensure_seed(uint64_t seed) { + if (m_seed != seed) { + this->seed(seed); + } + } +}; + +class PermutationRNGImpl : public PermutationRNG { + uint64_t m_seed, m_offset; + cudaStream_t m_stream; + +public: + PermutationRNGImpl(Handle* handle); + + void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout& layout) override; + + void seed(uint64_t seed) { m_seed = seed; } + + void ensure_seed(uint64_t seed) { + if (m_seed != seed) { + this->seed(seed); + } + } +}; + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/rng/opr_impl.cpp b/dnn/src/naive/rng/opr_impl.cpp index 97c9bd0c..3d24adb4 100644 --- a/dnn/src/naive/rng/opr_impl.cpp +++ b/dnn/src/naive/rng/opr_impl.cpp @@ -78,6 +78,157 @@ namespace { } } + template + T normal_sample(Xoroshiro128plus *rng){ + T v; + fill_gaussian(rng, &v, 1, T(0.f), T(1.f)); + 
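+        // (the rejection samplers below build on these scalar helpers:
+        // fill_gamma appears to follow the Marsaglia-Tsang (2000) squeeze
+        // method, and fill_poisson combines Knuth's product-of-uniforms
+        // scheme for lambda < 10 with Hörmann's PTRS transformed rejection
+        // for larger lambda)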
return v; + } + + template + T uniform_sample(Xoroshiro128plus *rng){ + return uniform_int2float((*rng)()); + } + + template + void fill_gamma(Xoroshiro128plus *rng, U *dst, size_t size, + U* shape, U* scale){ + for(size_t i = 0; i < size; ++i){ + T a = static_cast(shape[i]); + T b = static_cast(scale[i]); + T scale = b; + bool a_less_one = a < 1.f ? true : false; + if (a <= 0) { + dst[i] = U(0.0f); + continue; + }; + T d = a + (a_less_one ? 2.0f / 3.0f : -1.0f / 3.0f); + T c = 1.0f / std::sqrt(9.0f * d); + while (true) + { + T x, y; + x = normal_sample(rng); + y = 1.0f + c * x; + if ( y <= 0) continue; + T v = y * y * y; + T u = uniform_sample(rng); + T xx = x * x; + if ((u < 1.0f - 0.0331f * xx * xx) || + std::log(u) < 0.5f * xx + d * (1.0f - v + std::log(v))) + { + dst[i] = U(scale * d * v); + if (a_less_one) dst[i] *= U(std::pow(uniform_sample(rng), T(1.f / a))); + break; + } + } + } + } + + template + void fill_poisson(Xoroshiro128plus *rng, U *dst, U* lam, size_t size){ + for(size_t i = 0; i < size; ++i) { + T lambda = static_cast(lam[i]); + T exp_neg_lambda = std::exp(-lambda); + T log_lambda = std::log(lambda), sqrt_lambda = std::sqrt(lambda); + T b = 0.931f + 2.53f * sqrt_lambda; + T a = -0.059f + 0.02483f * b; + T inv_alpha = 1.1239f + 1.1328f / ( b - 3.4f); + T vr = 0.9277f - 3.6224f / (b - 2.f); + T u , v, u_shifted, k; + if( lambda == 0) { + dst[i] = U(0); + continue; + } + if ( lambda < 10){ + T prod = 1, x = 0; + u = 0; + while (true) + { + u = uniform_sample(rng); + prod *= u; + if ( prod <= exp_neg_lambda ){ + dst[i] = U(x); + break; + } + x += 1; + } + continue; + } + while (true) + { + u = uniform_sample(rng) - T(0.5f); + v = uniform_sample(rng); + u_shifted = T(0.5f) - std::abs(u); + k = std::floor((T(2.f) * a / u_shifted + b) * u + lambda + T(0.43f)); + if ( u_shifted >= 0.07 && v < vr ){ + dst[i] = U(k); + break; + } + if (k < 0 || (u_shifted < T(0.013f) && v > u_shifted)) { + continue; + } + if ((std::log(v) + std::log(inv_alpha) - std::log(a / (u_shifted * u_shifted) + b)) <= + (-lambda + k * log_lambda - std::lgamma(k + 1))) { + dst[i] = U(k); + break; + } + } + } + } + + template + void fill_beta(Xoroshiro128plus *rng, U *dst, U* alpha,U* beta, size_t size){ + for (size_t i = 0; i < size; ++i) { + T a = static_cast(alpha[i]), b = static_cast(beta[i]); + if( a < 1.0f && b < 1.0f){ + T u,v,x,y; + while (true) + { + u = uniform_sample(rng); + v = uniform_sample(rng); + x = std::pow(u, 1.0f / a); + y = std::pow(v, 1.0f / b); + if (x + y < 1.0f) { + if (x + y > 0) { + dst[i] = static_cast(x / (x + y)); + break; + }else { + T logx = std::log(u) / a; + T logy = std::log(v) / b; + T log_max = std::max(logx, logy); + logx -= log_max; + logy -= log_max; + dst[i] = static_cast (std::exp(logx - + std::log(std::exp(logx) + std::exp(logy)))); + break; + } + } + } + }else{ + T ga, gb, one = 1; + fill_gamma(rng, &ga, 1, &a, &one); + fill_gamma(rng, &gb, 1, &b, &one); + dst[i] = static_cast( ga / (ga + gb)); + } + } + } + + template + void fill_permutation(Xoroshiro128plus *rng, T *dst, size_t size){ + const int64_t mask = std::numeric_limits::max(); + for (size_t i = 0; i < size; ++i) { + dst[i] = static_cast(i); + } + for (int64_t i = size - 1; i > 0; --i) { + int64_t r = static_cast((*rng)()&mask) % (i + 1); + if (i != r) { + T tmp = dst[i]; + dst[i] = dst[r]; + dst[r] = tmp; + } + } + } + } // anonymous namespace uint64_t Splitmix64::operator() () { @@ -150,5 +301,98 @@ void GaussianRNGImpl::exec( } } +void GammaRNGImpl::exec(_megdnn_tensor_in shape, _megdnn_tensor_in scale, + 
_megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(shape.layout, scale.layout, dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + auto prng = &m_rng.ensure_seed(m_param.seed); + switch (dst.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + auto ptr = dst.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN_OPR({fill_gamma(prng, ptr, \ + size, shape.ptr(), scale.ptr());};); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + default: + megdnn_throw("bad dtype"); + } +} + +void PoissonRNGImpl::exec(_megdnn_tensor_in lam, + _megdnn_tensor_inout dst, _megdnn_workspace workspace) { + check_exec(lam.layout, dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + auto prng = &m_rng.ensure_seed(m_param.seed); + switch (dst.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + auto dst_ptr = dst.ptr(); \ + auto lam_ptr = lam.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN_OPR({fill_poisson(prng, dst_ptr, \ + lam_ptr, size );};); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + default: + megdnn_throw("bad dtype"); + } +} + +void BetaRNGImpl::exec(_megdnn_tensor_in alpha,_megdnn_tensor_in beta, + _megdnn_tensor_out dst, _megdnn_workspace workspace) { + check_exec(alpha.layout, beta.layout, dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + auto prng = &m_rng.ensure_seed(m_param.seed); + switch (dst.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + auto dst_ptr = dst.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN_OPR({fill_beta(prng, dst_ptr, \ + alpha.ptr(),beta.ptr(), size );};); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + default: + megdnn_throw("bad dtype"); + } +} + +void PermutationRNGImpl::exec( + _megdnn_tensor_inout dst, _megdnn_workspace workspace) { + check_exec(dst.layout, workspace.size); + auto size = dst.layout.total_nr_elems(); + auto prng = &m_rng.ensure_seed(m_param.seed); + switch (dst.layout.dtype.enumv()) { +#define cb(_dt) \ + case DTypeTrait<_dt>::enumv: \ + { \ + using ctype = DTypeTrait<_dt>::ctype; \ + ctype max_size = DTypeTrait<_dt>::max() - 1; \ + megdnn_assert((ctype(size) < max_size)); \ + auto ptr = dst.ptr(); \ + MEGDNN_DISPATCH_CPU_KERN_OPR({fill_permutation(prng, ptr, \ + size);};); \ + return; \ + } + cb(::megdnn::dtype::Float32) + cb(::megdnn::dtype::Int32) + cb(::megdnn::dtype::Int16) +#undef cb + default: + megdnn_throw("bad dtype"); + } +} + // vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/rng/opr_impl.h b/dnn/src/naive/rng/opr_impl.h index 6370843c..370c55d2 100644 --- a/dnn/src/naive/rng/opr_impl.h +++ b/dnn/src/naive/rng/opr_impl.h @@ -10,8 +10,8 @@ */ #pragma once -#include "megdnn/oprs.h" #include +#include "megdnn/oprs.h" namespace megdnn { namespace naive { @@ -19,12 +19,11 @@ namespace naive { //! see http://xoroshiro.di.unimi.it/splitmix64.c class Splitmix64 { uint64_t m_s; - public: - explicit Splitmix64(uint64_t seed = 0): - m_s{seed} - {} - uint64_t operator() (); +public: + explicit Splitmix64(uint64_t seed = 0) : m_s{seed} {} + + uint64_t operator()(); }; /*! @@ -36,51 +35,99 @@ class Xoroshiro128plus { return (x << k) | (x >> (64 - k)); } - public: - explicit Xoroshiro128plus(uint64_t seed = 0) { +public: + explicit Xoroshiro128plus(uint64_t seed = 0) { this->seed(seed); } + + //! 
reset state if seed changed + Xoroshiro128plus& ensure_seed(uint64_t seed) { + if (seed != m_init_seed) { this->seed(seed); } + return *this; + } - //! reset state if seed changed - Xoroshiro128plus& ensure_seed(uint64_t seed) { - if (seed != m_init_seed) { - this->seed(seed); - } - return *this; - } + //! set seed + void seed(uint64_t seed); + + uint64_t operator()(); +}; - //! set seed - void seed(uint64_t seed); +class UniformRNGImpl : public UniformRNG { + Xoroshiro128plus m_rng; - uint64_t operator() (); +public: + using UniformRNG::UniformRNG; + void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; -class UniformRNGImpl: public UniformRNG { +class GaussianRNGImpl : public GaussianRNG { Xoroshiro128plus m_rng; - public: - using UniformRNG::UniformRNG; - void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; +public: + using GaussianRNG::GaussianRNG; + void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; - size_t get_workspace_in_bytes(const TensorLayout&) override { - return 0; - } + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } }; -class GaussianRNGImpl: public GaussianRNG { +class GammaRNGImpl : public GammaRNG { Xoroshiro128plus m_rng; - public: - using GaussianRNG::GaussianRNG; - void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; +public: + using GammaRNG::GammaRNG; - size_t get_workspace_in_bytes(const TensorLayout&) override { - return 0; - } + void exec(_megdnn_tensor_in shape,_megdnn_tensor_in scale, _megdnn_tensor_out dst, + _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&,const TensorLayout&, + const TensorLayout&) override { + return 0; + } }; +class PoissonRNGImpl : public PoissonRNG { + Xoroshiro128plus m_rng; + +public: + using PoissonRNG::PoissonRNG; -} // namespace naive -} // namespace megdnn -// vim: syntax=cpp.doxygen + void exec(_megdnn_tensor_in lam, _megdnn_tensor_inout dst, + _megdnn_workspace) override; + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class BetaRNGImpl : public BetaRNG { + Xoroshiro128plus m_rng; + +public: + using BetaRNG::BetaRNG; + + void exec(_megdnn_tensor_in alpha,_megdnn_tensor_in beta, _megdnn_tensor_out dst, + _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&, + const TensorLayout&, + const TensorLayout&) override { + return 0; + } +}; + +class PermutationRNGImpl : public PermutationRNG { + Xoroshiro128plus m_rng; + +public: + using PermutationRNG::PermutationRNG; + + void exec(_megdnn_tensor_inout dst, _megdnn_workspace) override; + + size_t get_workspace_in_bytes(const TensorLayout&) override { return 0; } +}; +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/test/cuda/rng.cpp b/dnn/test/cuda/rng.cpp index a50d9613..77ba9602 100644 --- a/dnn/test/cuda/rng.cpp +++ b/dnn/test/cuda/rng.cpp @@ -8,36 +8,165 @@ * software distributed under the License is distributed on an * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
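+ *
+ * The statistical checks below compare empirical moments against the
+ * closed-form ones: Gamma(k, theta) has mean k*theta and variance
+ * k*theta^2; Poisson(lambda) has mean = variance = lambda; and
+ * Beta(a, b) has mean a/(a+b) and variance a*b/((a+b)^2*(a+b+1)).
+ * The `std` locals in the helpers hold these variances, matching the
+ * second component returned by get_mean_var.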
*/ -#include "megdnn/oprs.h" -#include "test/cuda/fixture.h" #include "test/naive/rng.h" +#include "megdnn/oprs.h" #include "test/common/tensor.h" +#include "test/cuda/fixture.h" namespace megdnn { namespace test { +namespace { + +template +void run_gamma(Handle* handle) { + using ctype = typename DTypeTrait::ctype; + auto opr = handle->create_operator(); + + TensorLayout ly{TensorShape{2000000 * 5}, T()}; + + SyncedTensor out(handle, ly); + SyncedTensor shape(handle, ly); + SyncedTensor scale(handle, ly); + auto shape_ptr = shape.ptr_mutable_host(); + auto scale_ptr = scale.ptr_mutable_host(); + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 2000000; ++j) { + shape_ptr[i * 2000000 + j] =2 * 0.3 * i + 0.3; + scale_ptr[i * 2000000 + j] = i * 0.2 + 0.1; + } + } + + opr->exec(shape.tensornd_dev(), scale.tensornd_dev(), out.tensornd_dev(), + {}); + + auto ptr = out.ptr_mutable_host(); + for (int i = 0; i < 5; ++i) { + float a = 2 * 0.3 * i + 0.3, b = i * 0.2 + 0.1; + float mean = a *b; + float std = a * (b * b); + auto stat = get_mean_var(ptr + i * 2000000, 2000000, ctype(mean)); + ASSERT_LE(std::abs(stat.first - mean), 0.01); + ASSERT_LE(std::abs(stat.second - std), 0.01); + } +} + +template +void run_poisson(Handle* handle) { + using ctype = typename DTypeTrait::ctype; + auto opr = handle->create_operator(); + + TensorLayout ly{TensorShape{200000 * 5}, T()}; + + SyncedTensor out(handle, ly); + SyncedTensor lam(handle, ly); + auto lam_ptr = lam.ptr_mutable_host(); + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 200000; ++j) { + lam_ptr[i * 200000 + j] = ctype(i + 1); + } + } + opr->exec(lam.tensornd_dev(), out.tensornd_dev(), {}); + + auto ptr = out.ptr_mutable_host(); + for (int i = 0; i < 5; ++i) { + auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(i + 1)); + ASSERT_LE(std::abs(stat.first - ctype(i + 1)), 0.01); + ASSERT_LE(std::abs(stat.second - ctype(i + 1)), 0.01); + } +} + +template +void run_beta(Handle* handle) { + using ctype = typename DTypeTrait::ctype; + auto opr = handle->create_operator(); + + TensorLayout ly{TensorShape{200000 * 5}, T()}; + + SyncedTensor out(handle, ly); + SyncedTensor alpha(handle, ly); + SyncedTensor beta(handle, ly); + auto alpha_ptr = alpha.ptr_mutable_host(); + auto beta_ptr = beta.ptr_mutable_host(); + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 200000; ++j) { + alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; + beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; + } + } + + opr->exec(alpha.tensornd_dev(), beta.tensornd_dev(), out.tensornd_dev(), + {}); + + auto ptr = out.ptr_mutable_host(); + for (int i = 0; i < 5; ++i) { + float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; + float mean = a / (a + b); + float std = a * b / ((a + b) * (a + b) * (a + b + 1)); + auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(mean)); + ASSERT_LE(std::abs(stat.first - mean), 0.01); + ASSERT_LE(std::abs(stat.second - std), 0.01); + } +} + +template +void run_permutation(Handle* handle) { + using ctype = typename DTypeTrait::ctype; + size_t sample_num = + std::min(200000, static_cast(DTypeTrait::max()) - 10); + + auto opr = handle->create_operator(); + opr->param().dtype = DTypeTrait::enumv; + TensorLayout ly{TensorShape{sample_num}, T()}; + Tensor workspace( + handle, + {TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()}); + SyncedTensor t(handle, ly); + + opr->exec(t.tensornd_dev(), + {workspace.ptr(), workspace.layout().total_nr_elems()}); + + auto ptr = t.ptr_mutable_host(); + auto size = t.layout().total_nr_elems(); + + std::vector 
res(size); + int not_same = 0; + for (size_t i = 0; i < size; ++i) { + if ((ptr[i] - ctype(i)) >= ctype(1)) not_same++; + res[i] = ptr[i]; + } + ASSERT_GT(not_same, 5000); + std::sort(res.begin(), res.end()); + for (size_t i = 0; i < size; ++i) { + ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); + } +} + +} // anonymous namespace + TEST_F(CUDA, UNIFORM_RNG_F32) { auto opr = handle_cuda()->create_operator(); + opr->param().dtype = DTypeTrait::enumv; SyncedTensor<> t(handle_cuda(), {TensorShape{200000}, dtype::Float32()}); opr->exec(t.tensornd_dev(), {}); - assert_uniform_correct(t.ptr_mutable_host(), - t.layout().total_nr_elems()); + assert_uniform_correct(t.ptr_mutable_host(), t.layout().total_nr_elems()); } TEST_F(CUDA, GAUSSIAN_RNG_F32) { auto opr = handle_cuda()->create_operator(); opr->param().mean = 0.8; opr->param().std = 2.3; - for (size_t size: {1, 200000, 200001}) { + opr->param().dtype = DTypeTrait::enumv; + for (size_t size : {1, 200000, 200001}) { TensorLayout ly{{size}, dtype::Float32()}; - Tensor workspace(handle_cuda(), - {TensorShape{opr->get_workspace_in_bytes(ly)}, - dtype::Byte()}); + Tensor workspace( + handle_cuda(), + {TensorShape{opr->get_workspace_in_bytes(ly)}, dtype::Byte()}); SyncedTensor<> t(handle_cuda(), ly); opr->exec(t.tensornd_dev(), - {workspace.ptr(), workspace.layout().total_nr_elems()}); + {workspace.ptr(), workspace.layout().total_nr_elems()}); auto ptr = t.ptr_mutable_host(); ASSERT_LE(std::abs(ptr[0] - 0.8), 2.3); @@ -50,10 +179,43 @@ TEST_F(CUDA, GAUSSIAN_RNG_F32) { } } -} // namespace test -} // namespace megdnn +TEST_F(CUDA, GAMMA_RNG_F32) { + run_gamma(handle_cuda()); +} -// vim: syntax=cpp.doxygen +TEST_F(CUDA, GAMMA_RNG_F16) { + run_gamma(handle_cuda()); +} +TEST_F(CUDA, POISSON_RNG_F32) { + run_poisson(handle_cuda()); +} + +TEST_F(CUDA, POISSON_RNG_F16) { + run_poisson(handle_cuda()); +} + +TEST_F(CUDA, BETA_RNG_F32) { + run_beta(handle_cuda()); +} + +TEST_F(CUDA, BETA_RNG_F16) { + run_beta(handle_cuda()); +} +TEST_F(CUDA, PERMUTATION_RNG_F32) { + run_permutation(handle_cuda()); +} + +TEST_F(CUDA, PERMUTATION_RNG_INT32) { + run_permutation(handle_cuda()); +} +TEST_F(CUDA, PERMUTATION_RNG_INT16) { + run_permutation(handle_cuda()); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/naive/rng.cpp b/dnn/test/naive/rng.cpp index 6b837378..75a82223 100644 --- a/dnn/test/naive/rng.cpp +++ b/dnn/test/naive/rng.cpp @@ -32,6 +32,7 @@ namespace { template void run_uniform(Handle *handle) { auto opr = handle->create_operator(); + opr->param().dtype = DTypeTrait::enumv; Tensor::ctype> t( handle, {TensorShape{200000}, dtype()}); opr->exec(t.tensornd(), {}); @@ -44,6 +45,7 @@ namespace { auto opr = handle->create_operator(); opr->param().mean = 0.8; opr->param().std = 2.3; + opr->param().dtype = DTypeTrait::enumv; Tensor t(handle, {TensorShape{200001}, dtype()}); opr->exec(t.tensornd(), {}); @@ -53,8 +55,131 @@ namespace { ASSERT_LE(std::abs(ptr[i] - 0.8), ctype(15)); } auto stat = get_mean_var(ptr, size, ctype(0.8)); + ASSERT_LE(std::abs(stat.first - 0.8), 5e-3); - ASSERT_LE(std::abs(stat.second - 2.3 * 2.3), 5e-2); + ASSERT_LE(std::abs(stat.second - 2.3 * 2.3), 5e-2); + } + + template + void run_gamma(Handle* handle){ + + using ctype = typename DTypeTrait::ctype; + auto opr = handle->create_operator(); + + TensorLayout ly{TensorShape{2000000*5}, dtype()}; + + Tensor out(handle, ly); + Tensor shape(handle, ly); + Tensor scale(handle, ly); + + auto shape_ptr = shape.ptr(); + auto scale_ptr = scale.ptr(); + for 
(int i = 0; i < 5; ++i) { + for (int j = 0; j < 2000000; ++j) { + shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.5; + scale_ptr[i * 2000000 + j] = i * 0.2 + 0.1; + } + } + opr->exec(shape.tensornd(), scale.tensornd(), out.tensornd(), {}); + + auto ptr = out.ptr(); + for(int i = 0; i < 5 ; ++i){ + float a = 2 * 0.3 * i + 0.5, b = i * 0.2 + 0.1; + float mean = a * b; + float std = a * (b * b) ; + auto stat = get_mean_var(ptr + i * 2000000, 2000000, ctype(mean)); + ASSERT_LE(std::abs(stat.first - mean), 0.01); + ASSERT_LE(std::abs(stat.second - std), 0.01); + } + } + + template + void run_poisson(Handle* handle){ + + using ctype = typename DTypeTrait::ctype; + auto opr = handle->create_operator(); + + TensorLayout ly{TensorShape{200000*5}, dtype()}; + + Tensor out(handle, ly); + Tensor lam(handle, ly); + + auto lam_ptr = lam.ptr(); + for(int i = 0; i < 5; ++i){ + for(int j = 0; j <200000; ++j){ + lam_ptr[i*200000 + j] = ctype(i + 1); + } + } + opr->exec(lam.tensornd(), out.tensornd(), {}); + + auto ptr = out.ptr(); + for(int i = 0; i < 5 ; ++i){ + auto stat = get_mean_var(ptr + i*200000, 200000, ctype(i + 1)); + ASSERT_LE(std::abs(stat.first - ctype(i + 1)), 0.01); + ASSERT_LE(std::abs(stat.second - ctype(i + 1)), 0.01); + } + } + + template + void run_beta(Handle* handle){ + + using ctype = typename DTypeTrait::ctype; + auto opr = handle->create_operator(); + + TensorLayout ly{TensorShape{200000*5}, dtype()}; + + Tensor out(handle, ly); + Tensor alpha(handle, ly); + Tensor beta(handle, ly); + + auto alpha_ptr = alpha.ptr(); + auto beta_ptr = beta.ptr(); + for (int i = 0; i < 5; ++i) { + for (int j = 0; j < 200000; ++j) { + alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1; + beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1; + } + } + opr->exec(alpha.tensornd(),beta.tensornd(), out.tensornd(), {}); + + auto ptr = out.ptr(); + for(int i = 0; i < 5 ; ++i){ + float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1; + float mean = a / (a + b); + float std = a * b / ((a + b) * (a + b) * (a + b + 1)); + auto stat = get_mean_var(ptr + i * 200000, 200000, ctype(mean)); + ASSERT_LE(std::abs(stat.first - mean), 0.01); + ASSERT_LE(std::abs(stat.second - std), 0.01); + } + } + + template + void run_permutation(Handle* handle){ + + using ctype = typename DTypeTrait::ctype; + size_t sample_num = std::min(200000, + static_cast(DTypeTrait::max()) - 10); + + auto opr = handle->create_operator(); + opr->param().dtype = DTypeTrait::enumv; + TensorLayout ly{TensorShape{sample_num}, dtype()}; + Tensor t(handle, ly); + opr->exec(t.tensornd(), {}); + + auto ptr = t.ptr(); + auto size = t.layout().total_nr_elems(); + + std::vector res(size); + int not_same = 0; + for(size_t i = 0; i < size; ++i){ + if ((ptr[i] - ctype(i)) >= 1 ) not_same++; + res[i] = ptr[i]; + } + ASSERT_GT(not_same, 5000); + std::sort(res.begin(),res.end()); + for(size_t i = 0; i < size; ++i){ + ASSERT_LE(std::abs(res[i] - ctype(i)), 1e-8); + } } } @@ -74,6 +199,42 @@ TEST_F(NAIVE, GAUSSIAN_RNG_F16) { DNN_INC_FLOAT16(run_gaussian(handle())); } +TEST_F(NAIVE, GAMMA_RNG_F32) { + run_gamma(handle()); +} + +TEST_F(NAIVE, GAMMA_RNG_F16) { + DNN_INC_FLOAT16(run_gamma(handle())); +} + +TEST_F(NAIVE, POISSON_RNG_F32) { + run_poisson(handle()); +} + +TEST_F(NAIVE, POISSON_RNG_F16) { + DNN_INC_FLOAT16(run_poisson(handle())); +} + +TEST_F(NAIVE, BETA_RNG_F32) { + run_beta(handle()); +} + +TEST_F(NAIVE, BETA_RNG_F16) { + DNN_INC_FLOAT16(run_beta(handle())); +} + +TEST_F(NAIVE, PERMUTATION_RNG_F32) { + run_permutation(handle()); +} + +TEST_F(NAIVE, PERMUTATION_RNG_INT32) { + 
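+    // run_permutation asserts two properties: sorting the output recovers
+    // 0..n-1 exactly (the result is a permutation of the identity), and
+    // more than 5000 elements left their original index (it was shuffled).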
run_permutation(handle()); +} + +TEST_F(NAIVE, PERMUTATION_RNG_INT16) { + run_permutation(handle()); +} + } // namespace test } // namespace megdnn diff --git a/imperative/python/megengine/random/__init__.py b/imperative/python/megengine/random/__init__.py index 996be02b..e59a2a56 100644 --- a/imperative/python/megengine/random/__init__.py +++ b/imperative/python/megengine/random/__init__.py @@ -6,8 +6,17 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from .distribution import normal, uniform -from .rng import RNG, seed +from .rng import RNG, beta, gamma, normal, permutation, poisson, seed, uniform +__all__ = [ + "RNG", + "beta", + "gamma", + "normal", + "permutation", + "poisson", + "seed", + "uniform", +] # pylint: disable=undefined-variable -del distribution, rng # type: ignore[name-defined] +del rng # type: ignore[name-defined] diff --git a/imperative/python/megengine/random/distribution.py b/imperative/python/megengine/random/distribution.py deleted file mode 100644 index be74a0d6..00000000 --- a/imperative/python/megengine/random/distribution.py +++ /dev/null @@ -1,95 +0,0 @@ -# -*- coding: utf-8 -*- -# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") -# -# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -from typing import Iterable, Optional - -from .. import Tensor -from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed -from .rng import _normal, _uniform - -__all__ = ["normal", "uniform"] - - -def normal( - mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None -) -> Tensor: - r""" - Random variable with Gaussian distribution :math:`N(\mu, \sigma)`. - - :param size: output tensor size. - :param mean: the mean or expectation of the distribution. - :param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`). - :return: the output tensor. - - Examples: - - .. testcode:: - - import megengine as mge - import megengine.random as rand - - x = rand.normal(mean=0, std=1, size=(2, 2)) - print(x.numpy()) - - Outputs: - - .. testoutput:: - :options: +SKIP - - [[-0.20235455 -0.6959438 ] - [-1.4939808 -1.5824696 ]] - - """ - return _normal( - mean=mean, - std=std, - size=size, - seed=_get_global_rng_seed(), - device=None, - handle=0, - ) - - -def uniform( - low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None -) -> Tensor: - r""" - Random variable with uniform distribution $U(0, 1)$. - - :param size: output tensor size. - :param low: lower range. - :param high: upper range. - :return: the output tensor. - - Examples: - - .. testcode:: - - import megengine as mge - import megengine.random as rand - - x = rand.uniform(size=(2, 2)) - print(x.numpy()) - - Outputs: - - .. 
testoutput:: - :options: +SKIP - - [[0.76901674 0.70496535] - [0.09365904 0.62957656]] - - """ - return _uniform( - low=low, - high=high, - size=size, - seed=_get_global_rng_seed(), - device=None, - handle=0, - ) diff --git a/imperative/python/megengine/random/rng.py b/imperative/python/megengine/random/rng.py index 448a232f..fe47a62c 100644 --- a/imperative/python/megengine/random/rng.py +++ b/imperative/python/megengine/random/rng.py @@ -6,8 +6,9 @@ # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +import collections import time -from typing import Iterable, Optional +from typing import Iterable, Optional, Union from numpy.random import MT19937 @@ -15,15 +16,97 @@ from .. import Tensor from ..core._imperative_rt.core2 import apply from ..core._imperative_rt.ops import delete_rng_handle as _delete_rng_handle from ..core._imperative_rt.ops import get_global_rng_seed as _get_global_rng_seed +from ..core._imperative_rt.ops import ( + get_rng_handle_compnode as _get_rng_handle_compnode, +) from ..core._imperative_rt.ops import new_rng_handle as _new_rng_handle from ..core._imperative_rt.ops import set_global_rng_seed as _set_global_rng_seed -from ..core.ops.builtin import GaussianRNG, UniformRNG +from ..core.ops.builtin import ( + BetaRNG, + GammaRNG, + GaussianRNG, + PermutationRNG, + PoissonRNG, + UniformRNG, +) from ..core.tensor import utils from ..device import get_default_device +__all__ = [ + "seed", + "RNG", + "uniform", + "normal", + "gamma", + "beta", + "poisson", + "permutation", +] + _rng = None +def _infer_broadcasted_shape(inps: Iterable[Tensor]) -> tuple: + broadcasted_ndim = inps[0].ndim + broadcasted_shape = list(inps[0]._tuple_shape) + for i in range(1, len(inps)): + cur_ndim = inps[i].ndim + cur_shape = list(inps[i]._tuple_shape) + n_dim = max(cur_ndim, broadcasted_ndim) + for j in range(n_dim - 1, -1, -1): + cur_dim = cur_ndim + j - n_dim + broad_dim = broadcasted_ndim + j - n_dim + cur_size = cur_shape[cur_dim] if cur_dim >= 0 else 1 + broad_size = broadcasted_shape[broad_dim] if broad_dim >= 0 else 1 + assert cur_size == broad_size or cur_size == 1 or broad_size == 1, ( + "The size of inps[{}] ({}) must match the size ({}) at " + "dim {}".format(i, cur_size, broad_size, j) + ) + broad_size = max(cur_size, broad_size) + if broad_dim < 0: + broadcasted_shape = [broad_size] + broadcasted_shape + broadcasted_ndim += 1 + else: + broadcasted_shape[broad_dim] = broad_size + return tuple(broadcasted_shape) + + +def _broadcast_tensors_with_size( + inps: Iterable[Tensor], size: Iterable[int] +) -> Iterable[Tensor]: + assert inps, "The inps cloud not be empty" + target_shape = _infer_broadcasted_shape(inps) + if isinstance(size, collections.abc.Iterable): + target_shape = tuple(size) + target_shape + target_ndim = len(target_shape) + for i in range(len(inps)): + if inps[i]._tuple_shape != target_shape: + inps[i] = ( + inps[i] + .reshape((1,) * (target_ndim - inps[i].ndim) + inps[i]._tuple_shape) + ._broadcast(target_shape) + ) + return inps + + +def _uniform( + low: float, + high: float, + size: Optional[Iterable[int]], + seed: int, + device: str, + handle: int, +) -> Tensor: + assert low < high, "Uniform is not defined when low >= high" + if size is None: + size = (1,) + op = UniformRNG(seed=seed, handle=handle, dtype="float32") + _ref = Tensor([], dtype="int32", device=device) + shape = utils.astensor1d(size, _ref, 
dtype="int32", device=device) + (output,) = apply(op, shape) + return low + (high - low) * output + + def _normal( mean: float, std: float, @@ -34,63 +117,477 @@ def _normal( ) -> Tensor: if size is None: size = (1,) - op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle) + op = GaussianRNG(seed=seed, mean=mean, std=std, handle=handle, dtype="float32") _ref = Tensor([], dtype="int32", device=device) shape = utils.astensor1d(size, _ref, dtype="int32", device=device) (output,) = apply(op, shape) return output -def _uniform( - low: float, - high: float, +def _gamma( + shape: Union[Tensor, float], + scale: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, - device: str, handle: int, ) -> Tensor: - assert low < high, "Uniform is not defined when low >= high" - if size is None: - size = (1,) - op = UniformRNG(seed=seed, handle=handle) + handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) + if not isinstance(shape, Tensor): + assert shape > 0, "Gamma is not defined when shape <= 0" + shape = Tensor(shape, dtype="float32", device=handle_cn) + if not isinstance(scale, Tensor): + assert scale > 0, "Gamma is not defined when scale <= 0" + scale = Tensor(scale, dtype="float32", device=handle_cn) + assert ( + handle_cn is None or handle_cn == shape.device + ), "The shape ({}) must be the same device with handle ({})".format( + shape.device, handle_cn + ) + assert ( + handle_cn is None or handle_cn == scale.device + ), "The scale ({}) must be the same device with handle ({})".format( + scale.device, handle_cn + ) + if isinstance(size, int) and size != 0: + size = (size,) + shape, scale = _broadcast_tensors_with_size([shape, scale], size) + op = GammaRNG(seed=seed, handle=handle) + (output,) = apply(op, shape, scale) + return output + + +def _beta( + alpha: Union[Tensor, float], + beta: Union[Tensor, float], + size: Optional[Iterable[int]], + seed: int, + handle: int, +) -> Tensor: + handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) + if not isinstance(alpha, Tensor): + assert alpha > 0, "Beta is not defined when alpha <= 0" + alpha = Tensor(alpha, dtype="float32", device=handle_cn) + if not isinstance(beta, Tensor): + assert beta > 0, "Beta is not defined when beta <= 0" + beta = Tensor(beta, dtype="float32", device=handle_cn) + assert ( + handle_cn is None or handle_cn == alpha.device + ), "The alpha ({}) must be the same device with handle ({})".format( + alpha.device, handle_cn + ) + assert ( + handle_cn is None or handle_cn == beta.device + ), "The beta ({}) must be the same device with handle ({})".format( + beta.device, handle_cn + ) + if isinstance(size, int) and size != 0: + size = (size,) + alpha, beta = _broadcast_tensors_with_size([alpha, beta], size) + op = BetaRNG(seed=seed, handle=handle) + (output,) = apply(op, alpha, beta) + return output + + +def _poisson( + lam: Union[Tensor, float], size: Optional[Iterable[int]], seed: int, handle: int +) -> Tensor: + handle_cn = None if handle == 0 else _get_rng_handle_compnode(handle) + if not isinstance(lam, Tensor): + assert lam > 0, "Poisson is not defined when lam <= 0" + lam = Tensor(lam, dtype="float32", device=handle_cn) + if isinstance(size, int) and size != 0: + size = (size,) + assert ( + handle_cn is None or handle_cn == lam.device + ), "The lam ({}) must be the same device with handle ({})".format( + lam.device, handle_cn + ) + (lam,) = _broadcast_tensors_with_size([lam], size) + op = PoissonRNG(seed=seed, handle=handle) + (output,) = apply(op, lam) + return output + + +def 
_permutation(n: int, seed: int, device: str, handle: int, dtype: str) -> Tensor: + assert isinstance(n, int) and n > 0, "Permutation is not defined when n <= 0" + size = (n,) + op = PermutationRNG(seed=seed, handle=handle, dtype=dtype) _ref = Tensor([], dtype="int32", device=device) shape = utils.astensor1d(size, _ref, dtype="int32", device=device) (output,) = apply(op, shape) - return low + (high - low) * output + return output class RNG: - def __init__(self, seed=0, device=None): - self.seed = seed - self.device = device if device else get_default_device() - self.handle = _new_rng_handle(self.device, self.seed) + + r""" + :class:`RNG` exposes a number of methods for generating random numbers. + + :param seed: random seed used to initialize the pseudo-random number generator. + Default: None + :param device: the device of generated tensor. Default: None + + Examples: + + .. testcode:: + + import megengine.random as rand + rng = rand.RNG(seed=100) + x = rng.uniform(size=(2, 2)) + print(x.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [[0.84811664 0.6147553 ] + [0.59429836 0.64727545]] + + """ + + def __init__(self, seed: int = None, device: str = None): + self._device = device if device else get_default_device() + if seed is not None: + self._seed = seed + self._handle = _new_rng_handle(self._device, self._seed) + else: + self._seed = _get_global_rng_seed + self._handle = 0 + self._device = None def uniform( self, low: float = 0, high: float = 1, size: Optional[Iterable[int]] = None ): + r""" + Random variable with uniform distribution $U(0, 1)$. + + :param low: lower range. Default: 0 + :param high: upper range. Default: 1 + :param size: the size of output tensor. Default: None + :return: the output tensor. + + Examples: + + .. testcode:: + + import megengine as mge + import megengine.random as rand + + x = rand.uniform(size=(2, 2)) + print(x.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [[0.91600335 0.6680226 ] + [0.2046729 0.2769141 ]] + + """ + _seed = self._seed() if callable(self._seed) else self._seed return _uniform( low=low, high=high, size=size, - seed=self.seed, - device=self.device, - handle=self.handle, + seed=_seed, + device=self._device, + handle=self._handle, ) def normal( self, mean: float = 0, std: float = 1, size: Optional[Iterable[int]] = None ): + r""" + Random variable with Gaussian distribution :math:`N(\mu, \sigma)`. + + :param mean: the mean or expectation of the distribution. Default: 0 + :param std: the standard deviation of the distribution (variance = :math:`\sigma ^ 2`). + Default: 1 + :param size: the size of output tensor. Default: None + :return: the output tensor. + + Examples: + + .. testcode:: + + import megengine as mge + import megengine.random as rand + + x = rand.normal(mean=0, std=1, size=(2, 2)) + print(x.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [[-1.4010863 -0.9874344 ] + [ 0.56373274 0.79656655]] + + """ + _seed = self._seed() if callable(self._seed) else self._seed return _normal( mean=mean, std=std, size=size, - seed=self.seed, - device=self.device, - handle=self.handle, + seed=_seed, + device=self._device, + handle=self._handle, + ) + + def gamma( + self, + shape: Union[Tensor, float], + scale: Union[Tensor, float] = 1, + size: Optional[Iterable[int]] = None, + ): + r""" + Random variable with Gamma distribution :math:`\Gamma(k, \theta)`. + + The corresponding probability density function is + + .. 
math:: + p(x)=x^{k-1} \frac{e^{-x / \theta}}{\theta^{k} \Gamma(k)} + \quad \text { for } x>0 \quad k, \theta>0, + + where :math:`\Gamma(k)` is the gamma function, + + .. math:: + \Gamma(k)=(k-1) ! \quad \text { for } \quad k>0. + + :param shape: the shape parameter (sometimes designated "k") of the distribution. + Must be non-negative. + :param scale: the scale parameter (sometimes designated "theta") of the distribution. + Must be non-negative. Default: 1 + :param size: the size of output tensor. If shape and scale are scalars and given size is, e.g., + `(m, n)`, then the output shape is `(m, n)`. If shape or scale is a Tensor and given size + is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(shape, scale).shape`. + The broadcast rules are consistent with `numpy.broadcast`. Default: None + :return: the output tensor. + + Examples: + + .. testcode:: + + import megengine as mge + import megengine.random as rand + + x = rand.gamma(shape=2, scale=1, size=(2, 2)) + print(x.numpy()) + + shape = mge.Tensor([[ 1], + [10]], dtype="float32") + scale = mge.Tensor([1,5], dtype="float32") + + x = rand.gamma(shape=shape, scale=scale) + print(x.numpy()) + + x = rand.gamma(shape=shape, scale=scale, size=2) + print(x.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [[1.5064533 4.0689363 ] + [0.71639484 1.4551026 ]] + + [[ 0.4352188 11.399335 ] + [ 9.1888 52.009277 ]] + + [[[ 1.1726005 3.9654975 ] + [13.656933 36.559006 ]] + [[ 0.25848487 2.5540342 ] + [11.960409 21.031536 ]]] + + """ + _seed = self._seed() if callable(self._seed) else self._seed + return _gamma( + shape=shape, scale=scale, size=size, seed=_seed, handle=self._handle + ) + + def beta( + self, + alpha: Union[Tensor, float], + beta: Union[Tensor, float], + size: Optional[Iterable[int]] = None, + ): + r""" + Random variable with Beta distribution :math:`\operatorname{Beta}(\alpha, \beta)`. + + The corresponding probability density function is + + .. math:: + p(x)=\frac{1}{\mathrm{~B}(\alpha, \beta)} x^{\alpha-1}(1-x)^{\beta-1} + \quad \text { for } \alpha, \beta>0, + + where :math:`\mathrm{~B}(\alpha, \beta)` is the beta function, + + .. math:: + \mathrm{~B}(\alpha, \beta)=\int_{0}^{1} t^{\alpha-1}(1-t)^{\beta-1} d t. + + :param alpha: the alpha parameter of the distribution. Must be non-negative. + :param beta: the beta parameter of the distribution. Must be non-negative. + :param size: the size of output tensor. If alpha and beta are scalars and given size is, e.g., + `(m, n)`, then the output shape is `(m, n)`. If alpha or beta is a Tensor and given size + is, e.g., `(m, n)`, then the output shape is `(m, n) + broadcast(alpha, beta).shape`. + The broadcast rules are consistent with `numpy.broadcast`. Default: None + :return: the output tensor. + + Examples: + + .. testcode:: + + import megengine as mge + import megengine.random as rand + + x = rand.beta(alpha=2, beta=1, size=(2, 2)) + print(x.numpy()) + + alpha = mge.Tensor([[0.5], + [ 3]], dtype="float32") + beta = mge.Tensor([0.5,5], dtype="float32") + + x = rand.beta(alpha=alpha, beta=beta) + print(x.numpy()) + + x = rand.beta(alpha=alpha, beta=beta, size=2) + print(x.numpy()) + + Outputs: + + .. 
testoutput:: + :options: +SKIP + + [[0.582565 0.91763186] + [0.86963767 0.6088103 ]] + + [[0.41503012 0.16438372] + [0.90159506 0.47588003]] + + [[[0.55195075 0.01111084] + [0.95298755 0.25048104]] + [[0.11680304 0.13859665] + [0.997879 0.43259275]]] + + """ + _seed = self._seed() if callable(self._seed) else self._seed + return _beta(alpha=alpha, beta=beta, size=size, seed=_seed, handle=self._handle) + + def poisson(self, lam: Union[float, Tensor], size: Optional[Iterable[int]] = None): + r""" + Random variable with poisson distribution :math:`\operatorname{Poisson}(\lambda)`. + + The corresponding probability density function is + + .. math:: + f(k ; \lambda)=\frac{\lambda^{k} e^{-\lambda}}{k !}, + + where k is the number of occurrences :math:`({\displaystyle k=0,1,2...})`. + + :param lam: the lambda parameter of the distribution. Must be non-negative. + :param size: the size of output tensor. If lam is a scalar and given size is, e.g., `(m, n)`, + then the output shape is `(m, n)`. If lam is a Tensor with shape `(k, v)` and given + size is, e.g., `(m, n)`, then the output shape is `(m, n, k, v)`. Default: None. + :return: the output tensor. + + Examples: + + .. testcode:: + + import megengine as mge + import megengine.random as rand + + x = rand.poisson(lam=2., size=(1, 3)) + print(x.numpy()) + + lam = mge.Tensor([[1.,1.], + [10,10]], dtype="float32") + + x = rand.poisson(lam=lam) + print(x.numpy()) + + x = rand.poisson(lam=lam, size=(1,3)) + print(x.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [[3. 1. 3.]] + + [[ 2. 2.] + [12. 11.]] + + [[[[ 1. 1.] + [11. 4.]] + [[ 0. 0.] + [ 9. 13.]] + [[ 0. 1.] + [ 7. 12.]]]] + + """ + _seed = self._seed() if callable(self._seed) else self._seed + return _poisson(lam=lam, size=size, seed=_seed, handle=self._handle) + + def permutation(self, n: int, *, dtype: str = "int32"): + r""" + Generates a random permutation of integers from :math:`0` to :math:`n - 1`. + + :param n: the upper bound. Must be larger than 0. + :param dtype: the output data type. int32, int16 and float32 are + supported. Default: int32 + :return: the output tensor. + + Examples: + + .. testcode:: + + import megengine as mge + import megengine.random as rand + + x = rand.permutation(n=10, dtype="int32") + print(x.numpy()) + + x = rand.permutation(n=10, dtype="float32") + print(x.numpy()) + + Outputs: + + .. testoutput:: + :options: +SKIP + + [4 5 0 7 3 8 6 1 9 2] + [3. 4. 9. 0. 6. 8. 7. 1. 5. 2.] 
+ + """ + _seed = self._seed() if callable(self._seed) else self._seed + return _permutation( + n=n, seed=_seed, device=self._device, handle=self._handle, dtype=dtype ) def __del__(self): - _delete_rng_handle(self.handle) + if self._handle != 0: + _delete_rng_handle(self._handle) + + +def _default_rng(): + r"""Default constructor for :class:`RNG`.""" + return RNG(seed=None, device=None) + + +_default_handle = _default_rng() + +uniform = _default_handle.uniform +normal = _default_handle.normal +gamma = _default_handle.gamma +beta = _default_handle.beta +poisson = _default_handle.poisson +permutation = _default_handle.permutation def _random_seed_generator(): diff --git a/imperative/python/src/ops.cpp b/imperative/python/src/ops.cpp index 22ed862e..840a617c 100644 --- a/imperative/python/src/ops.cpp +++ b/imperative/python/src/ops.cpp @@ -476,4 +476,5 @@ void init_ops(py::module m) { }, py::call_guard()); m.def("set_global_rng_seed", &rng::set_global_rng_seed); m.def("get_global_rng_seed", &rng::get_global_rng_seed); + m.def("get_rng_handle_compnode", &rng::get_rng_handle_compnode); } diff --git a/imperative/python/test/unit/random/test_rng.py b/imperative/python/test/unit/random/test_rng.py index b3bf7c20..a13eb877 100644 --- a/imperative/python/test/unit/random/test_rng.py +++ b/imperative/python/test/unit/random/test_rng.py @@ -9,8 +9,8 @@ import numpy as np import pytest -import megengine -from megengine import is_cuda_available, tensor +import megengine.functional as F +from megengine import Tensor from megengine.core._imperative_rt import CompNode from megengine.core._imperative_rt.core2 import apply from megengine.core._imperative_rt.ops import ( @@ -18,10 +18,16 @@ from megengine.core._imperative_rt.ops import ( get_global_rng_seed, new_rng_handle, ) -from megengine.core.ops.builtin import GaussianRNG, UniformRNG +from megengine.core.ops.builtin import ( + BetaRNG, + GammaRNG, + GaussianRNG, + PermutationRNG, + PoissonRNG, + UniformRNG, +) from megengine.distributed.helper import get_device_count_by_fork from megengine.random import RNG -from megengine.random.rng import _normal, _uniform @pytest.mark.skipif( @@ -34,22 +40,24 @@ def test_gaussian_op(): 11, 12, ) - shape = tensor(shape, dtype="int32") - op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0) + shape = Tensor(shape, dtype="int32") + op = GaussianRNG(seed=get_global_rng_seed(), mean=1.0, std=3.0, dtype="float32") (output,) = apply(op, shape) assert np.fabs(output.numpy().mean() - 1.0) < 1e-1 - assert np.sqrt(output.numpy().var()) - 3.0 < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - 3.0) < 1e-1 assert str(output.device) == str(CompNode("xpux")) + assert output.dtype == np.float32 cn = CompNode("xpu2") seed = 233333 h = new_rng_handle(cn, seed) - op = GaussianRNG(seed=seed, mean=3.0, std=1.0, handle=h) + op = GaussianRNG(seed=seed, mean=3.0, std=1.0, dtype="float32", handle=h) (output,) = apply(op, shape) delete_rng_handle(h) assert np.fabs(output.numpy().mean() - 3.0) < 1e-1 - assert np.sqrt(output.numpy().var()) - 1.0 < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - 1.0) < 1e-1 assert str(output.device) == str(cn) + assert output.dtype == np.float32 @pytest.mark.skipif( @@ -62,20 +70,138 @@ def test_uniform_op(): 11, 12, ) - shape = tensor(shape, dtype="int32") - op = UniformRNG(seed=get_global_rng_seed()) + shape = Tensor(shape, dtype="int32") + op = UniformRNG(seed=get_global_rng_seed(), dtype="float32") (output,) = apply(op, shape) assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 assert 
str(output.device) == str(CompNode("xpux")) + assert output.dtype == np.float32 cn = CompNode("xpu2") seed = 233333 h = new_rng_handle(cn, seed) - op = UniformRNG(seed=seed, handle=h) + op = UniformRNG(seed=seed, dtype="float32", handle=h) (output,) = apply(op, shape) delete_rng_handle(h) assert np.fabs(output.numpy().mean() - 0.5) < 1e-1 assert str(output.device) == str(cn) + assert output.dtype == np.float32 + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", +) +def test_gamma_op(): + _shape, _scale = 2, 0.8 + _expected_mean, _expected_std = _shape * _scale, np.sqrt(_shape) * _scale + + shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32") + scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32") + op = GammaRNG(seed=get_global_rng_seed(), handle=0) + (output,) = apply(op, shape, scale) + assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 + assert str(output.device) == str(CompNode("xpux")) + + cn = CompNode("xpu2") + seed = 233333 + h = new_rng_handle(cn, seed) + shape = F.full([8, 9, 11, 12], value=_shape, dtype="float32", device="xpu2") + scale = F.full([8, 9, 11, 12], value=_scale, dtype="float32", device="xpu2") + op = GammaRNG(seed=seed, handle=h) + (output,) = apply(op, shape, scale) + delete_rng_handle(h) + assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 + assert str(output.device) == str(cn) + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", +) +def test_beta_op(): + _alpha, _beta = 2, 0.8 + _expected_mean = _alpha / (_alpha + _beta) + _expected_std = np.sqrt( + _alpha * _beta / ((_alpha + _beta) ** 2 * (_alpha + _beta + 1)) + ) + + alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32") + beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32") + op = BetaRNG(seed=get_global_rng_seed()) + (output,) = apply(op, alpha, beta) + assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 + assert str(output.device) == str(CompNode("xpux")) + + cn = CompNode("xpu2") + seed = 233333 + h = new_rng_handle(cn, seed) + alpha = F.full([8, 9, 11, 12], value=_alpha, dtype="float32", device=cn) + beta = F.full([8, 9, 11, 12], value=_beta, dtype="float32", device=cn) + op = BetaRNG(seed=seed, handle=h) + (output,) = apply(op, alpha, beta) + delete_rng_handle(h) + assert np.fabs(output.numpy().mean() - _expected_mean) < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - _expected_std) < 1e-1 + assert str(output.device) == str(cn) + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", +) +def test_poisson_op(): + lam = F.full([8, 9, 11, 12], value=2, dtype="float32") + op = PoissonRNG(seed=get_global_rng_seed()) + (output,) = apply(op, lam) + assert np.fabs(output.numpy().mean() - 2.0) < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1 + assert str(output.device) == str(CompNode("xpux")) + + cn = CompNode("xpu2") + seed = 233333 + h = new_rng_handle(cn, seed) + lam = F.full([8, 9, 11, 12], value=2, dtype="float32", device=cn) + op = PoissonRNG(seed=seed, handle=h) + (output,) = apply(op, lam) + delete_rng_handle(h) + assert np.fabs(output.numpy().mean() - 2.0) < 1e-1 + assert np.fabs(np.sqrt(output.numpy().var()) - np.sqrt(2.0)) < 1e-1 + assert str(output.device) == str(cn) + + 
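+
+# All of the op tests above share one acceptance pattern: draw a large sample
+# and require the empirical mean/std to land within 1e-1 of the theoretical
+# moments. A minimal NumPy-only sketch of that pattern, with numpy.random
+# standing in for the MegEngine ops (an illustrative helper whose name is
+# ours; it is not called by this suite):
+def _moment_check_sketch(lam=2.0, n=8 * 9 * 11 * 12):
+    # Poisson: mean == variance == lam, hence std == sqrt(lam).
+    sample = np.random.poisson(lam=lam, size=n)
+    assert np.fabs(sample.mean() - lam) < 1e-1
+    assert np.fabs(sample.std() - np.sqrt(lam)) < 1e-1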
+@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 2, reason="xpu counts need > 2", +) +def test_permutation_op(): + n = 1000 + + def test_permutation_op_dtype(dtype): + def sum_result(res, fun): + return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))]) + + shape = Tensor((n,), dtype="int32") + op = PermutationRNG(seed=get_global_rng_seed(), dtype=dtype) + (output,) = apply(op, shape) + assert sum_result(output, lambda x: x) < 500 + assert sum_result(output, np.sort) == n + assert str(output.device) == str(CompNode("xpux")) + assert output.dtype == dtype + + cn = CompNode("xpu2") + seed = 233333 + h = new_rng_handle(cn, seed) + op = PermutationRNG(seed=seed, handle=h, dtype=dtype) + (output,) = apply(op, shape) + delete_rng_handle(h) + assert sum_result(output, lambda x: x) < 500 + assert sum_result(output, np.sort) == n + assert str(output.device) == str(cn) + assert output.dtype == dtype + + test_permutation_op_dtype(np.float32) + test_permutation_op_dtype(np.int32) + test_permutation_op_dtype(np.int16) @pytest.mark.skipif( @@ -133,3 +259,131 @@ def test_NormalRNG(): assert all(out.shape.numpy() == np.array([20, 30, 40])) assert np.abs(out.mean().numpy() - mean) / std < 0.1 assert np.abs(np.std(out.numpy()) - std) < 0.1 + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", +) +def test_GammaRNG(): + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + out1 = m1.gamma(2, size=(100,)) + out1_ = m1.uniform(size=(100,)) + out2 = m2.gamma(2, size=(100,)) + out3 = m3.gamma(2, size=(100,)) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + assert not (out1.numpy() == out1_.numpy()).all() + + shape = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0") + scale = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0") + expected_mean = (shape * scale).numpy() + expected_std = (F.sqrt(shape) * scale).numpy() + out = m1.gamma(shape=shape, scale=scale, size=(20, 30, 40)) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (20, 30, 40, 2, 3) + else: + assert all(out.shape.numpy() == np.array([20, 30, 40, 2, 3])) + assert ( + np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std + ).mean() < 0.1 + assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1 + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", +) +def test_BetaRNG(): + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + out1 = m1.beta(2, 1, size=(100,)) + out1_ = m1.uniform(size=(100,)) + out2 = m2.beta(2, 1, size=(100,)) + out3 = m3.beta(2, 1, size=(100,)) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + assert not (out1.numpy() == out1_.numpy()).all() + + alpha = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32, device="xpu0") + beta = Tensor([0.5, 1, 1.5], dtype=np.float32, device="xpu0") + expected_mean = (alpha / (alpha + beta)).numpy() + expected_std = ( + F.sqrt(alpha * beta / (F.pow(alpha + beta, 2) * (alpha + beta + 1))) + ).numpy() + out = m1.beta(alpha=alpha, beta=beta, size=(20, 30)) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (20, 30, 2, 3) + else: + assert all(out.shape.numpy() == 
np.array([20, 30, 2, 3])) + assert ( + np.abs(out.mean(axis=(0, 1)).numpy() - expected_mean) / expected_std + ).mean() < 0.1 + assert (np.abs(np.std(out.numpy(), axis=(0, 1)) - expected_std)).mean() < 0.1 + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", +) +def test_PoissonRNG(): + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + lam = Tensor([[2, 3, 4], [9, 10, 11]], dtype=np.float32) + out1 = m1.poisson(lam.to("xpu0"), size=(100,)) + out2 = m2.poisson(lam.to("xpu1"), size=(100,)) + out3 = m3.poisson(lam.to("xpu0"), size=(100,)) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + + out = m1.poisson(lam.to("xpu0"), size=(20, 30)) + out_shp = out.shape + expected_shape = (20, 30) + lam._tuple_shape + if isinstance(out_shp, tuple): + assert out_shp == expected_shape + else: + assert all(out.shape.numpy() == np.array(expected_shape)) + lam = lam.numpy() + + assert (np.abs(out.mean(axis=(0, 1)).numpy() - lam) / np.sqrt(lam)).mean() < 0.1 + assert np.abs(np.std(out.numpy(), axis=(0, 1)) - np.sqrt(lam)).mean() < 0.1 + + +@pytest.mark.skipif( + get_device_count_by_fork("xpu") <= 1, reason="xpu counts need > 1", +) +def test_PermutationRNG(): + m1 = RNG(seed=111, device="xpu0") + m2 = RNG(seed=111, device="xpu1") + m3 = RNG(seed=222, device="xpu0") + out1 = m1.permutation(n=1000) + out1_ = m1.uniform(size=(1000,)) + out2 = m2.permutation(n=1000) + out3 = m3.permutation(n=1000) + + np.testing.assert_equal(out1.numpy(), out2.numpy()) + assert out1.device == "xpu0" and out2.device == "xpu1" + assert not (out1.numpy() == out3.numpy()).all() + assert not (out1.numpy() == out1_.numpy()).all() + + out = m1.permutation(n=1000) + out_shp = out.shape + if isinstance(out_shp, tuple): + assert out_shp == (1000,) + else: + assert all(out.shape.numpy() == np.array([1000])) + + def sum_result(res, fun): + return sum([1 if i == v else 0 for i, v in enumerate(fun(res.numpy()))]) + + assert sum_result(out, lambda x: x) < 500 + assert sum_result(out, np.sort) == 1000 diff --git a/imperative/src/impl/ops/rng.cpp b/imperative/src/impl/ops/rng.cpp index 0e6adbca..1e97a89d 100644 --- a/imperative/src/impl/ops/rng.cpp +++ b/imperative/src/impl/ops/rng.cpp @@ -176,6 +176,20 @@ struct OpMeth { using Param = DnnOp::Param; using OpNode = mgb::opr::UniformRNG; static Param make_param(const UniformRNG& rng) { + auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed, rng.dtype.enumv()}; + } +}; + +template <> +struct OpMeth { + using DnnOp = megdnn::PoissonRNG; + using Param = DnnOp::Param; + using OpNode = mgb::opr::PoissonRNG; + static Param make_param(const PoissonRNG& rng) { auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); mgb_assert(handle_seed == rng.seed, "inconsistent rng seed: rng op: %lu handle: %lu", @@ -194,16 +208,168 @@ struct OpMeth { mgb_assert(handle_seed == rng.seed, "inconsistent rng seed: rng op: %lu handle: %lu", handle_seed, rng.seed); - return {handle_seed, rng.mean, rng.std}; + return {handle_seed, rng.mean, rng.std, rng.dtype.enumv()}; + } +}; + +template <> +struct OpMeth { + using DnnOp = megdnn::GammaRNG; + using Param = DnnOp::Param; + using OpNode = mgb::opr::GammaRNG; + static Param make_param(const GammaRNG& rng) { + auto handle_seed 
= RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed}; } }; +template <> +struct OpMeth { + using DnnOp = megdnn::PermutationRNG; + using Param = DnnOp::Param; + using OpNode = mgb::opr::PermutationRNG; + static Param make_param(const PermutationRNG& rng) { + auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed, rng.dtype.enumv()}; + } +}; + +template <> +struct OpMeth { + using DnnOp = megdnn::BetaRNG; + using Param = DnnOp::Param; + using OpNode = mgb::opr::BetaRNG; + static Param make_param(const BetaRNG& rng) { + auto handle_seed = RNGDnnOpManager::get_seed(rng.handle); + mgb_assert(handle_seed == rng.seed, + "inconsistent rng seed: rng op: %lu handle: %lu", + handle_seed, rng.seed); + return {handle_seed}; + } +}; + +template +struct _InferLayout; + +template +struct _RNGOprMaker; + +template +struct _RNGOprInvoker; + +template<> +struct _InferLayout +{ + template + static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){ + TensorShape tshape; + auto hv = inp->get_value().proxy_to_default_cpu(); + cg::copy_tensor_value_to_shape(tshape, hv); + return TensorLayout(tshape, rng.dtype); + } + + template + static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){ + TensorLayout out_layout = inp.layout; + out_layout.dtype = rng.dtype; + if (inp.layout.ndim == 0 || inp.value.empty()) { + out_layout.ndim = 0; + return out_layout; + } + mgb_assert( + inp.layout.ndim == 1, + "target shape of %s expects ndim=1; got ndim=%lu actually", + rng.dyn_typeinfo()->name, + inp.layout.ndim); + size_t target_ndim = inp.layout.shape[0]; + out_layout.ndim = target_ndim; + auto* ptr = inp.value.ptr(); + for (size_t i = 0; i < target_ndim; ++i) { + out_layout.shape[i] = ptr[i]; + } + return out_layout; + } +}; + +template<> +struct _InferLayout +{ + template + static TensorLayout do_infer(const TensorPtr& inp, const Op& rng){ + return inp->layout(); + } + + template + static TensorLayout do_infer(const LogicalTensorDesc& inp, const Op& rng){ + size_t size = inp.layout.total_nr_elems(); + mgb_assert( + size > 0, + "target size of %s expects size>0; got size=%lu actually", + rng.dyn_typeinfo()->name, + size); + return inp.layout; + } +}; + +#define _INST_RNG_INVOLKER(DNN_NR_INPUTS) \ +template<> \ +struct _RNGOprInvoker { \ + template \ + static void exec(Opr *dnn_op, const SmallVector& inputs,const TensorPtr& dest){ \ + size_t wk_size = 0; \ + wk_size = dnn_op->get_workspace_in_bytes(_FOR_EACH_IN(->layout())dest->layout()); \ + auto workspace = Blob::make(dest->comp_node(), wk_size); \ + megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); \ + dnn_op->exec(_FOR_EACH_IN(->dev_tensor().as_megdnn()) \ + dest->dev_tensor().as_megdnn(), dnn_wk); \ + } \ +}; + +#define _INST_RNG_MAKER(MGB_NR_INPUTS) \ +template<> \ +struct _RNGOprMaker { \ + template \ + static SymbolVar make(const VarNodeArray& inputs, const Op& rng){ \ + auto param = OpMeth::make_param(rng); \ + OperatorNodeConfig config; \ + if (rng.handle) { \ + config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; \ + } else { \ + config = {rng.make_name()}; \ + } \ + return OpMeth::OpNode::make(_FOR_EACH_IN() param, config); \ + } \ +}; + +#define _FOR_EACH_IN(subfix) +_INST_RNG_INVOLKER(0) +#undef _FOR_EACH_IN + +#define _FOR_EACH_IN(subfix) 
inputs[0] subfix, +_INST_RNG_INVOLKER(1) +_INST_RNG_MAKER(1) +#undef _FOR_EACH_IN + +#define _FOR_EACH_IN(subfix) inputs[0] subfix, inputs[1] subfix, +_INST_RNG_INVOLKER(2) +_INST_RNG_MAKER(2) +#undef _FOR_EACH_IN + +#undef _INST_RNG_INVOLKER +#undef _INST_RNG_MAKER + template void exec(const OpDef& op, const SmallVector& inputs, const SmallVector& outputs) { auto&& rng = op.cast_final_safe(); + auto dest = outputs[0]; - auto cn = dest->comp_node(); auto handle = rng.handle; if (!handle) { @@ -224,38 +390,40 @@ void exec(const OpDef& op, const SmallVector& inputs, handle_seed, dnn_op->param().seed); } dnn_op->param() = OpMeth::make_param(rng); - - // allocate workspace - size_t wk_size = dnn_op->get_workspace_in_bytes(dest->layout()); - auto workspace = Blob::make(cn, wk_size); - megdnn::Workspace dnn_wk(workspace->storage().get(), wk_size); - - dnn_op->exec(dest->dev_tensor().as_megdnn(), dnn_wk); + _RNGOprInvoker::DnnOp::NR_INPUTS>::exec(dnn_op,inputs,dest); } template SmallVector infer_output_attrs( const OpDef& op, const SmallVector& inputs) { LogicalTensorDesc dest; - auto handle = op.cast_final_safe().handle; + auto&& rng = op.cast_final_safe(); + auto handle = rng.handle; if (handle) { dest.comp_node = RNGDnnOpManager::get_comp_node(handle); } else { dest.comp_node = inputs[0]->comp_node(); } - - auto hv = inputs[0]->get_value().proxy_to_default_cpu(); - TensorShape tshape; - cg::copy_tensor_value_to_shape(tshape, hv); - dest.layout = TensorLayout(tshape, dtype::Float32()); + constexpr bool rng_with_shape = OpMeth::DnnOp::NR_INPUTS == 0; + if(!rng_with_shape){ + for(int i = 0; i < inputs.size(); ++i){ + mgb_assert(inputs[i]->comp_node() == dest.comp_node, + "%s expects the device of inputs[%d] to be same as the device of handle; " + "got %s and %s actually", rng.dyn_typeinfo()->name, i, + inputs[i]->comp_node().to_string().c_str(), + dest.comp_node.to_string().c_str()); + } + } + dest.layout = _InferLayout::do_infer(inputs[0], rng); return {dest}; } template SmallVector apply_on_physical_tensor( const OpDef& def, const SmallVector& inputs) { - auto desc = infer_output_attrs(def, inputs); SmallVector outputs; + SmallVector desc; + desc = infer_output_attrs(def, inputs); for (auto&& i : desc) { outputs.push_back(Tensor::make(i.layout, i.comp_node)); } @@ -268,51 +436,32 @@ SymbolVar apply_on_var_node( const OpDef& def, const VarNodeArray& inputs) { size_t nr_inp = inputs.size(); + constexpr size_t dnn_nr_inp = OpMeth::DnnOp::NR_INPUTS; auto&& rng = def.cast_final_safe(); - mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", - rng.dyn_typeinfo()->name, - nr_inp); - auto param = OpMeth::make_param(rng); - OperatorNodeConfig config; - if (rng.handle) { - config = {rng.make_name(), RNGDnnOpManager::get_comp_node(rng.handle)}; - } else { - config = {rng.make_name()}; + if(dnn_nr_inp == 0){ + mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", + rng.dyn_typeinfo()->name, + nr_inp); } - return OpMeth::OpNode::make(inputs[0], param, config); + constexpr size_t mgb_nr_inp = dnn_nr_inp + !dnn_nr_inp; + return _RNGOprMaker::make(inputs, rng); } -template +template std::tuple, bool> infer_output_attrs_fallible( const OpDef& def, const SmallVector& inputs) { - auto&& xxx_rng_def = def.cast_final_safe(); + LogicalTensorDesc dest; + auto&& xxx_rng_def = def.cast_final_safe(); size_t nr_inp = inputs.size(); - mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", - xxx_rng_def.dyn_typeinfo()->name, - nr_inp); - - auto&& tshp = inputs[0]; - - TensorLayout out_layout 
= tshp.layout; - out_layout.dtype = dtype::Float32(); - if (tshp.layout.ndim == 0 || tshp.value.empty()) { - out_layout.ndim = 0; - return {{{out_layout, tshp.comp_node}}, true}; + constexpr bool rng_with_shape = OpMeth::DnnOp::NR_INPUTS == 0; + if (rng_with_shape){ + mgb_assert(nr_inp == 1, "%s expects 1 inputs; got %lu actually", + xxx_rng_def.dyn_typeinfo()->name, + nr_inp); } - mgb_assert( - tshp.layout.ndim == 1, - "target shape of %s expects ndim=1; got ndim=%lu actually", - xxx_rng_def.dyn_typeinfo()->name, - tshp.layout.ndim); - - size_t target_ndim = tshp.layout.shape[0]; - out_layout.ndim = target_ndim; - auto* ptr = tshp.value.ptr(); - for (size_t i = 0; i < target_ndim; ++i) { - out_layout.shape[i] = ptr[i]; - } - - return {{{out_layout, tshp.comp_node}}, true}; + dest.comp_node = inputs[0].comp_node; + dest.layout = _InferLayout::do_infer(inputs[0], xxx_rng_def); + return {{dest}, true}; } } // anonymous namespace @@ -333,6 +482,10 @@ uint64_t get_global_rng_seed() { return RNGDnnOpManager::get_glob_default_seed(); } +CompNode get_rng_handle_compnode(Handle handle){ + return RNGDnnOpManager::get_comp_node(handle); +} + #define REG_RNG_OP(NAME)\ namespace { \ OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ @@ -344,6 +497,11 @@ OP_TRAIT_REG(NAME, NAME, OpMeth::OpNode) \ REG_RNG_OP(UniformRNG) REG_RNG_OP(GaussianRNG) +REG_RNG_OP(GammaRNG) +REG_RNG_OP(PermutationRNG) +REG_RNG_OP(PoissonRNG) +REG_RNG_OP(BetaRNG) +#undef REG_RNG_OP } // namespace mgb::imperative::rng diff --git a/imperative/src/include/megbrain/imperative/ops/rng.h b/imperative/src/include/megbrain/imperative/ops/rng.h index 7f7e5505..0fc28427 100644 --- a/imperative/src/include/megbrain/imperative/ops/rng.h +++ b/imperative/src/include/megbrain/imperative/ops/rng.h @@ -22,5 +22,6 @@ Handle new_handle(CompNode comp_node, uint64_t seed); size_t delete_handle(Handle handle); void set_global_rng_seed(uint64_t seed); uint64_t get_global_rng_seed(); +CompNode get_rng_handle_compnode(Handle handle); } // namespace mgb::imperative::rng diff --git a/imperative/src/test/rng.cpp b/imperative/src/test/rng.cpp index d03b53b8..bb4f8202 100644 --- a/imperative/src/test/rng.cpp +++ b/imperative/src/test/rng.cpp @@ -42,14 +42,72 @@ void check_rng_basic(Args&& ...args) { } } +template +void check_rng_with_input_basic(const CompNode &cn, + const SmallVector &inputs, Args&& ...args) { + Handle h = new_handle(cn, 123); + auto op = Op::make(std::forward(args)..., h); + auto outputs = OpDef::apply_on_physical_tensor(*op, inputs); + ASSERT_TRUE(outputs[0]->layout().eq_shape(inputs[0]->shape())); + ASSERT_TRUE(cn == outputs[0]->comp_node()); + // sync before delete handle + for (auto&& p: outputs) { + p->get_value(); + } + delete_handle(h); +} + +TEST(TestImperative, PoissonRNGBasic) { + REQUIRE_XPU(2); + for (auto&& cn: {CompNode::load("xpu0"), CompNode::load("xpu1")}){ + TensorShape shape{5, 3000}; + HostTensorND lam{cn, shape, dtype::Float32()}; + auto lam_ptr = lam.ptr(); + for( int i = 0; i < 5*3000; ++i) lam_ptr[i] = 2; + SmallVector inputs{Tensor::make(lam)}; + check_rng_with_input_basic(cn, inputs, 123); + } +} + +TEST(TestImperative, BetaRNGBasic) { + REQUIRE_XPU(2); + for (auto&& cn: {CompNode::load("xpu0"), CompNode::load("xpu1")}){ + TensorShape shape{5, 3000}; + HostTensorND alpha{cn, shape, dtype::Float32()}, + beta{cn, shape, dtype::Float32()}; + auto lam_ptr = alpha.ptr(), beta_ptr = beta.ptr(); + for( int i = 0; i < 5*3000; ++i) lam_ptr[i] = 2, beta_ptr[i] = 2; + SmallVector inputs{Tensor::make(alpha), Tensor::make(beta)}; + 
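+        // check_rng_with_input_basic applies the op built from these inputs,
+        // then asserts that the output keeps inputs[0]'s shape and the
+        // handle's comp node, and syncs every output before deleting the
+        // handle.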
check_rng_with_input_basic(cn, inputs, 123); + } +} + +TEST(TestImperative, GammaRNGBasic) { + REQUIRE_XPU(2); + for (auto&& cn: {CompNode::load("xpu0"), CompNode::load("xpu1")}){ + TensorShape size{5, 3000}; + HostTensorND shape{cn, size, dtype::Float32()}, + scale{cn, size, dtype::Float32()}; + auto shape_ptr = shape.ptr(), scale_ptr = scale.ptr(); + for( int i = 0; i < 5*3000; ++i) shape_ptr[i] = 2, scale_ptr[i] = 2; + SmallVector inputs{Tensor::make(shape), Tensor::make(scale)}; + check_rng_with_input_basic(cn, inputs, 123); + } +} + TEST(TestImperative, UniformRNGBasic) { REQUIRE_XPU(2); - check_rng_basic(123); + check_rng_basic(123, dtype::Float32()); } TEST(TestImperative, GaussianRNGBasic) { REQUIRE_XPU(2); - check_rng_basic(123, 2.f, 3.f); + check_rng_basic(123, 2.f, 3.f, dtype::Float32()); +} + +TEST(TestImperative, PermutationRNGBasic) { + REQUIRE_XPU(2); + check_rng_basic(123, dtype::Int32()); } // vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index ffcee330..2aeb01fa 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -123,9 +123,13 @@ def UniformRNG: MgbHashableOp<"UniformRNG", [UniformRNGParam]> { let hashFunction = [{ return mgb::hash_pair_combine( mgb::hash($_self.dyn_typeinfo()), - mgb::hash($_self.handle)); + mgb::hash_pair_combine( + mgb::hash($_self.handle), + mgb::hash($_self.dtype.enumv()) + ) + ); }]; - let cmpFunction = [{return $0.handle == $1.handle;}]; + let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; } def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { @@ -139,11 +143,70 @@ def GaussianRNG: MgbHashableOp<"GaussianRNG", [GaussianRNGParam]> { mgb::hash($_self.handle), mgb::hash_pair_combine( mgb::hash($_self.mean), - mgb::hash($_self.std)) + mgb::hash_pair_combine( + mgb::hash($_self.std), + mgb::hash($_self.dtype.enumv()) + ) + ) + ) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std && $0.dtype == $1.dtype;}]; +} + +def GammaRNG: MgbHashableOp<"GammaRNG", [GammaRNGParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash($_self.handle) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle;}]; +} + +def PoissonRNG: MgbHashableOp<"PoissonRNG", [PoissonRNGParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash($_self.handle) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle;}]; +} + +def BetaRNG: MgbHashableOp<"BetaRNG", [BetaRNGParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash($_self.handle) + ); + }]; + let cmpFunction = [{return $0.handle == $1.handle;}]; +} + +def PermutationRNG: MgbHashableOp<"PermutationRNG", [PermutationRNGParam]> { + let extraArguments = (ins + MgbSizeTAddr:$handle + ); + let hashFunction = [{ + return mgb::hash_pair_combine( + mgb::hash($_self.dyn_typeinfo()), + mgb::hash_pair_combine( + mgb::hash($_self.handle), + mgb::hash($_self.dtype.enumv()) ) ); }]; - let cmpFunction = [{return $0.handle == $1.handle && $0.mean == $1.mean && $0.std == $1.std;}]; + let cmpFunction = [{return $0.handle == $1.handle && $0.dtype == $1.dtype;}]; } 
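Each RNG op above folds every field that distinguishes two instances (`handle`, plus `dtype` where present) into its hash via nested `hash_pair_combine` calls, and mirrors the same fields in `cmpFunction`, so ops that differ only in those fields never compare equal and never get deduplicated. A minimal Python sketch of that keying (illustrative names; the mixing constant is an assumption, not megbrain's real implementation):

    def hash_pair_combine(a: int, b: int) -> int:
        # Stand-in for mgb::hash_pair_combine: mix two hashes into one
        # 64-bit value.
        return (a * 20141203 + b) & 0xFFFFFFFFFFFFFFFF

    def permutation_rng_dedup_key(type_hash: int, handle: int, dtype_enum: int) -> int:
        # Mirrors the hashFunction above: hash(type) combined with
        # hash(handle) combined with hash(dtype).
        return hash_pair_combine(type_hash, hash_pair_combine(handle, dtype_enum))
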
def Linspace: MgbHashableOp<"Linspace", [LinspaceParam]> { diff --git a/src/opr/impl/rand.cpp b/src/opr/impl/rand.cpp index 1045a235..34583ea0 100644 --- a/src/opr/impl/rand.cpp +++ b/src/opr/impl/rand.cpp @@ -19,46 +19,21 @@ using namespace mgb; using namespace opr; using namespace intl; -namespace { - - -template -struct RNGName; - -template<> -struct RNGName { - static constexpr const char* name = "uniform_rng"; -}; - -template<> -struct RNGName { - static constexpr const char* name = "gaussian_rng"; -}; - -} // anonymous namespace - -RNGOprBase::RNGOprBase(const OperatorNodeBaseCtorParam &opr, VarNode *shape): - Super(opr) +template +RNGOprBase::RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param ¶m): + Super(opr),m_param(param) { - add_input({shape}); - add_output(None)->dtype(dtype::Float32()); - cg::add_workspace_output(this); - - // disable dedup - add_equivalence_component>(this); } -RNGOprBase::~RNGOprBase() { -} - -cg::OperatorNodeBase::NodeProp* RNGOprBase::do_make_node_prop() const { - auto prop = Super::do_make_node_prop(); - prop->add_flag(NodeProp::Flag::IMPURE_FUNC); - prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); - return prop; +template +UniqPtrWithCN RNGOprBase::create_megdnn_opr() { + auto opr = intl::create_megdnn_opr(comp_node()); + opr->param() = param(); + return opr; } -void RNGOprBase::ensure_megdnn_opr() { +template +void RNGOprBase::ensure_megdnn_opr() { if (!m_dnn_opr || m_dnn_opr.comp_node() != comp_node()) { // activate comp_node for curandCreateGenerator in create_megdnn_opr comp_node().activate(); @@ -66,53 +41,120 @@ void RNGOprBase::ensure_megdnn_opr() { } } -void RNGOprBase::init_output_static_infer_desc() { - using namespace cg::static_infer; - auto &&mgr = owner_graph()->static_infer_manager(); - auto infer_out = [](TensorShape &dest, const InpVal &inp) { - cg::copy_tensor_value_to_shape(dest, inp.val.at(0).value()); - return true; - }; - auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { - ensure_megdnn_opr(); - dest.ndim = 1; - dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( - {inp.val.at(0).shape(), output(0)->dtype()}); - return true; - }; - mgr.register_shape_infer(output(0), - {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out}); - mgr.register_shape_infer(output(1), - {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); +/* ================= RNG with shape ================= */ +#define _INST_RNG_OPR_WITH_SHAPE(RNGOpr, name) \ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \ +cg::OperatorNodeBase::NodeProp* RNGOpr::do_make_node_prop() const { \ + auto prop = Super::do_make_node_prop(); \ + prop->add_flag(NodeProp::Flag::IMPURE_FUNC); \ + prop->reset_dep_type(input(), {NodeProp::DepType::HOST_VALUE}); \ + return prop; \ +} \ +RNGOpr::RNGOpr(VarNode *shape, const Param ¶m, \ + const OperatorNodeConfig &config): \ + Super({shape->owner_graph(), config, (name), {shape}}, param) \ +{ \ + DType dtype = DType::from_enum(param.dtype); \ + add_input({shape}); \ + add_output(None)->dtype(dtype); \ + cg::add_workspace_output(this); \ + add_equivalence_component>(this); \ +} \ +SymbolVar RNGOpr::make(SymbolVar shape, const Param ¶m, \ + const OperatorNodeConfig &config){ \ + return shape.insert_single_output_opr(shape.node(), param, config); \ +} \ +void RNGOpr::init_output_static_infer_desc() { \ + using namespace cg::static_infer; \ + auto &&mgr = owner_graph()->static_infer_manager(); \ + auto infer_out = [](TensorShape &dest, const InpVal &inp) { \ + cg::copy_tensor_value_to_shape(dest, 
inp.val.at(0).value()); \ + return true; \ + }; \ + auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \ + ensure_megdnn_opr(); \ + dest.ndim = 1; \ + dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \ + {inp.val.at(0).shape(), output(0)->dtype()}); \ + return true; \ + }; \ + mgr.register_shape_infer(output(0), \ + {SourceType::DEP, {{input(0), DepType::VALUE}}, infer_out}); \ + mgr.register_shape_infer(output(1), \ + {SourceType::DEP, {{output(0), DepType::SHAPE}}, infer_wk}); \ +} \ +void RNGOpr::scn_do_execute() { \ + m_dnn_opr->exec(output(0)->dev_tensor().as_megdnn(), \ + get_megdnn_workspace_from_var(output(1))); \ } -void RNGOprBase::scn_do_execute() { - m_dnn_opr->exec( - output(0)->dev_tensor().as_megdnn(), - get_megdnn_workspace_from_var(output(1))); -} - -template -RNGOpr::RNGOpr(VarNode *shape, const Param ¶m, - const OperatorNodeConfig &config): - Super({shape->owner_graph(), config, RNGName::name, {shape}}, - shape), - m_param(param) -{ -} - -template -SymbolVar RNGOpr::make(SymbolVar shape, const Param ¶m, - const OperatorNodeConfig &config) { - return shape.insert_single_output_opr(shape.node(), param, config); -} - -template -UniqPtrWithCN RNGOpr::create_megdnn_opr() { - auto opr = intl::create_megdnn_opr(comp_node()); - opr->param() = param(); - return opr; -} +_INST_RNG_OPR_WITH_SHAPE(UniformRNG,"uniform_rng") +_INST_RNG_OPR_WITH_SHAPE(GaussianRNG,"gaussian_rng") +_INST_RNG_OPR_WITH_SHAPE(PermutationRNG,"permutation_rng") +#undef _INST_RNG_OPR_WITH_SHAPE + +/* ================= RNG with input ================= */ +#define _AS_MEGDNN(idx) input((idx))->dev_tensor().as_megdnn() +#define _INFER_WK_DEPS(idx) {input((idx)), DepType::SHAPE} +#define _INFER_WK_ARGS(idx) {inp.val.at((idx)).shape(), input((idx))->dtype()} + +#define _INST_RNG_OPR_WITH_INPUT(RNGOpr, name) \ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); \ +RNGOpr::RNGOpr(_INPUTS(VarNode*,), const Param ¶m, \ + const OperatorNodeConfig &config): \ + Super({i0->owner_graph(), config, (name), {_INPUTS(,)}}, param) \ +{ \ + add_input({_INPUTS(,)}); \ + add_output(None)->dtype(i0->dtype()); \ + cg::add_workspace_output(this); \ + add_equivalence_component>(this); \ +} \ +SymbolVar RNGOpr::make(_INPUTS(SymbolVar,), const Param ¶m, \ + const OperatorNodeConfig &config){ \ + return i0.insert_single_output_opr(_INPUTS(,.node()), param, config); \ +} \ +void RNGOpr::init_output_static_infer_desc() { \ + using namespace cg::static_infer; \ + auto &&mgr = owner_graph()->static_infer_manager(); \ + auto infer_wk = [this](TensorShape &dest, const InpVal &inp) { \ + ensure_megdnn_opr(); \ + dest.ndim = 1; \ + dest.shape[0] = m_dnn_opr->get_workspace_in_bytes( \ + _FOR_EACH(_INFER_WK_ARGS), \ + {output(0)->shape(), output(0)->dtype()}); \ + return true; \ + }; \ + mgr.register_shape_infer(output(0),ShapeInferDesc::make_identity(input(0))); \ + mgr.register_shape_infer(output(1),{SourceType::DEP, {_FOR_EACH(_INFER_WK_DEPS)}, \ + infer_wk}); \ +} \ +void RNGOpr::add_input_layout_constraint(){ \ + for (auto i : input()) i->add_layout_constraint_contiguous(); \ +}; \ +void RNGOpr::scn_do_execute() { \ + m_dnn_opr->exec(_FOR_EACH(_AS_MEGDNN),output(0)->dev_tensor().as_megdnn(), \ + get_megdnn_workspace_from_var(output(1))); \ +} + +/* ================= 1 input ================= */ +#define _INPUTS(prefix, subfix) prefix i0 subfix +#define _FOR_EACH(cb) cb(0) +_INST_RNG_OPR_WITH_INPUT(PoissonRNG,"poisson_rng") +#undef _INPUTS +#undef _FOR_EACH + +/* ================= 2 input ================= */ +#define _INPUTS(prefix,subfix) 
prefix i0 subfix, prefix i1 subfix +#define _FOR_EACH(cb) cb(0), cb(1) +_INST_RNG_OPR_WITH_INPUT(BetaRNG,"beta_rng") +_INST_RNG_OPR_WITH_INPUT(GammaRNG,"gamma_rng") +#undef _INPUTS +#undef _FOR_EACH + +#undef _AS_MEGDNN +#undef _INFER_WK_DEPS +#undef _INFER_WK_ARGS +#undef _INST_RNG_OPR_WITH_INPUT #define IMPL(_cls) \ MGB_IMPL_OPR_GRAD(_cls) { \ @@ -123,13 +165,21 @@ UniqPtrWithCN RNGOpr::create_megdnn_opr() { namespace mgb { namespace opr { namespace intl { -template class RNGOpr<::megdnn::GaussianRNG>; -template class RNGOpr<::megdnn::UniformRNG>; +template class RNGOprBase<::megdnn::GaussianRNG>; +template class RNGOprBase<::megdnn::UniformRNG>; +template class RNGOprBase<::megdnn::GammaRNG>; +template class RNGOprBase<::megdnn::PermutationRNG>; +template class RNGOprBase<::megdnn::BetaRNG>; +template class RNGOprBase<::megdnn::PoissonRNG>; #if MGB_ENABLE_GRAD IMPL(GaussianRNG); IMPL(UniformRNG); +IMPL(GammaRNG); +IMPL(PoissonRNG); +IMPL(PermutationRNG); +IMPL(BetaRNG); #endif -} +} } } diff --git a/src/opr/impl/rand.sereg.h b/src/opr/impl/rand.sereg.h index 68869f0b..315b6975 100644 --- a/src/opr/impl/rand.sereg.h +++ b/src/opr/impl/rand.sereg.h @@ -17,6 +17,10 @@ namespace opr { MGB_SEREG_OPR(UniformRNG, 1); MGB_SEREG_OPR(GaussianRNG, 1); + MGB_SEREG_OPR(GammaRNG, 2); + MGB_SEREG_OPR(PoissonRNG, 1); + MGB_SEREG_OPR(PermutationRNG, 1); + MGB_SEREG_OPR(BetaRNG, 2); } // namespace opr } // namespace mgb diff --git a/src/opr/include/megbrain/opr/rand.h b/src/opr/include/megbrain/opr/rand.h index bbf6f05e..57d02248 100644 --- a/src/opr/include/megbrain/opr/rand.h +++ b/src/opr/include/megbrain/opr/rand.h @@ -14,7 +14,6 @@ #include "megbrain/graph.h" #include "megbrain/opr/internal/out_shape_by_sym_var.h" #include "megbrain/opr/internal/megdnn_opr_wrapper.h" - #include "megdnn/oprs.h" namespace mgb { @@ -22,60 +21,81 @@ namespace opr { namespace intl { +template MGB_DEFINE_CLS_WITH_SUPER(RNGOprBase, cg::SingleCNOperatorNodeBase) // { - UniqPtrWithCN m_dnn_opr; - - void ensure_megdnn_opr(); - void init_output_static_infer_desc() override; - void scn_do_execute() override final; - - protected: - RNGOprBase(const OperatorNodeBaseCtorParam &opr, VarNode *shape); - ~RNGOprBase(); - NodeProp* do_make_node_prop() const override; - - virtual UniqPtrWithCN create_megdnn_opr() = 0; -}; - -template -MGB_DEFINE_OPR_CLASS(RNGOpr, RNGOprBase) // { - public: using Param = typename MegDNNOpr::Param; - - RNGOpr(VarNode *shape, const Param ¶m, - const OperatorNodeConfig &config); - - static SymbolVar make(SymbolVar shape, const Param ¶m = {}, - const OperatorNodeConfig &config = {}); - - static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, - const OperatorNodeConfig &config, - const Param ¶m = {}) { - return make(var_from_tensor_shape(graph, config, "rng", shape), - param, config); - } - const Param& param() const { return m_param; } private: Param m_param; - UniqPtrWithCN create_megdnn_opr() override; + UniqPtrWithCN create_megdnn_opr(); + + protected: + ~RNGOprBase(){}; + RNGOprBase(const OperatorNodeBaseCtorParam &opr, const Param ¶m); + void ensure_megdnn_opr(); + UniqPtrWithCN m_dnn_opr; +}; + +/* ================= RNG with shape ================= */ +#define _DEFINE_RNG_OPR_WITH_SHAPE_CLASS(RNG) \ +MGB_DEFINE_OPR_CLASS(RNG,RNGOprBase) \ + cg::OperatorNodeBase::NodeProp* do_make_node_prop() const override; \ + public: \ + RNG(VarNode *shape, const Param ¶m, const OperatorNodeConfig &config); \ + static SymbolVar make(SymbolVar shape, const Param ¶m = {}, \ + const OperatorNodeConfig 
&config = {}); \ + static SymbolVar make(ComputingGraph &graph, const TensorShape &shape, \ + const OperatorNodeConfig &config, \ + const Param ¶m = {}) { \ + return make(var_from_tensor_shape(graph, config, "rng", shape), \ + param, config); \ + } \ + void init_output_static_infer_desc() override; \ + void scn_do_execute() override; \ }; -#undef _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL -#define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL template -MGB_DYN_TYPE_OBJ_FINAL_IMPL(RNGOpr); -#undef _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL -#define _MGB_DYN_TYPE_OBJ_FINAL_IMPL_TPL +_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(UniformRNG) +_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(GaussianRNG) +_DEFINE_RNG_OPR_WITH_SHAPE_CLASS(PermutationRNG) +#undef _DEFINE_RNG_OPR_WITH_SHAPE_CLASS + +/* ================= RNG with input ================= */ +#define _DEFINE_RNG_OPR_WITH_INPUT_CLASS(RNG) \ +MGB_DEFINE_OPR_CLASS(RNG, RNGOprBase) \ + void add_input_layout_constraint() override; \ + public: \ + RNG(_INPUTS(VarNode*), const Param ¶m, \ + const OperatorNodeConfig &config); \ + static SymbolVar make(_INPUTS(SymbolVar),const Param ¶m = {}, \ + const OperatorNodeConfig &config = {}); \ + void init_output_static_infer_desc() override; \ + void scn_do_execute() override; \ +}; -} // intl +/* ================= 1 input ================= */ +#define _INPUTS(preifx) preifx i0 +_DEFINE_RNG_OPR_WITH_INPUT_CLASS(PoissonRNG) +#undef _INPUTS -using UniformRNG = intl::RNGOpr; -using GaussianRNG = intl::RNGOpr; +/* ================= 2 input ================= */ +#define _INPUTS(preifx) preifx i0, preifx i1 +_DEFINE_RNG_OPR_WITH_INPUT_CLASS(BetaRNG) +_DEFINE_RNG_OPR_WITH_INPUT_CLASS(GammaRNG) +#undef _INPUTS +#undef _DEFINE_RNG_OPR_WITH_INPUT_CLASS + +} // intl +using UniformRNG = intl::UniformRNG; +using GaussianRNG = intl::GaussianRNG; +using GammaRNG = intl::GammaRNG; +using PermutationRNG = intl::PermutationRNG; +using PoissonRNG = intl::PoissonRNG; +using BetaRNG = intl::BetaRNG; } // namespace opr } // namespace mgb diff --git a/src/opr/test/rand.cpp b/src/opr/test/rand.cpp index c6bd18b6..50c41ea6 100644 --- a/src/opr/test/rand.cpp +++ b/src/opr/test/rand.cpp @@ -19,84 +19,76 @@ using namespace mgb; namespace { - struct BasicStat { - double mean, std, min, max; - - static BasicStat make(const float *ptr, size_t size, - double mean_expect = 0) { - double sum = 0, sum2 = 0, - min = std::numeric_limits::max(), - max = std::numeric_limits::lowest(); - for (size_t i = 0; i < size; ++ i) { - double cur = ptr[i]; - min = std::min(min, cur); - max = std::max(max, cur); - cur -= mean_expect; - sum += cur; - sum2 += cur * cur; - } - - double mean = sum / size + mean_expect, - std = sqrt((sum2 - sum * sum / size) / (size - 1)); - return {mean, std, min, max}; +struct BasicStat { + double mean, std, min, max; + + static BasicStat make(const float* ptr, size_t size, + double mean_expect = 0) { + double sum = 0, sum2 = 0, min = std::numeric_limits::max(), + max = std::numeric_limits::lowest(); + for (size_t i = 0; i < size; ++i) { + double cur = ptr[i]; + min = std::min(min, cur); + max = std::max(max, cur); + cur -= mean_expect; + sum += cur; + sum2 += cur * cur; } - }; - void check_reproducibility( - thin_function make) { - auto graph = ComputingGraph::make(); - constexpr size_t SIZE = 123; - - // out[func][opr][run] - HostTensorND out[2][2][2]; - - auto run = [&](int fid) { - SymbolVar - o0 = make(cg::var_from_tensor_shape(*graph, - {CompNode::load("xpu0")}, "shp0", {SIZE}), 0), - o1 = make(cg::var_from_tensor_shape(*graph, - {CompNode::load("xpu0")}, "shp0", {SIZE}), 
diff --git a/src/opr/test/rand.cpp b/src/opr/test/rand.cpp
index c6bd18b6..50c41ea6 100644
--- a/src/opr/test/rand.cpp
+++ b/src/opr/test/rand.cpp
@@ -19,84 +19,76 @@ using namespace mgb;
 namespace {
-    struct BasicStat {
-        double mean, std, min, max;
-
-        static BasicStat make(const float *ptr, size_t size,
-                double mean_expect = 0) {
-            double sum = 0, sum2 = 0,
-                   min = std::numeric_limits<double>::max(),
-                   max = std::numeric_limits<double>::lowest();
-            for (size_t i = 0; i < size; ++ i) {
-                double cur = ptr[i];
-                min = std::min(min, cur);
-                max = std::max(max, cur);
-                cur -= mean_expect;
-                sum += cur;
-                sum2 += cur * cur;
-            }
-
-            double mean = sum / size + mean_expect,
-                   std = sqrt((sum2 - sum * sum / size) / (size - 1));
-            return {mean, std, min, max};
+struct BasicStat {
+    double mean, std, min, max;
+
+    static BasicStat make(const float* ptr, size_t size,
+                          double mean_expect = 0) {
+        double sum = 0, sum2 = 0, min = std::numeric_limits<double>::max(),
+               max = std::numeric_limits<double>::lowest();
+        for (size_t i = 0; i < size; ++i) {
+            double cur = ptr[i];
+            min = std::min(min, cur);
+            max = std::max(max, cur);
+            cur -= mean_expect;
+            sum += cur;
+            sum2 += cur * cur;
         }
-    };
-    void check_reproducibility(
-            thin_function<SymbolVar(SymbolVar, uint64_t)> make) {
-        auto graph = ComputingGraph::make();
-        constexpr size_t SIZE = 123;
-
-        // out[func][opr][run]
-        HostTensorND out[2][2][2];
-
-        auto run = [&](int fid) {
-            SymbolVar
-                o0 = make(cg::var_from_tensor_shape(*graph,
-                            {CompNode::load("xpu0")}, "shp0", {SIZE}), 0),
-                o1 = make(cg::var_from_tensor_shape(*graph,
-                            {CompNode::load("xpu0")}, "shp0", {SIZE}), 1);
-            HostTensorND host_o0, host_o1;
-            auto func = graph->compile({
-                    make_callback_copy(o0, host_o0),
-                    make_callback_copy(o1, host_o1)});
-            for (int i = 0; i < 2; ++ i) {
-                func->execute();
-                out[fid][0][i].copy_from(host_o0);
-                out[fid][1][i].copy_from(host_o1);
-            }
-        };
-        run(0);
-        run(1);
-
-        for (int i = 0; i < 2; ++ i) {
-            for (int j = 0; j < 2; ++ j)
-                MGB_ASSERT_TENSOR_EQ(out[0][i][j], out[1][i][j]);
+        double mean = sum / size + mean_expect,
+               std = sqrt((sum2 - sum * sum / size) / (size - 1));
+        return {mean, std, min, max};
+    }
+};
+
+void check_reproducibility(std::shared_ptr<ComputingGraph> graph, size_t size,
+                           thin_function<SymbolVar(uint64_t)> make) {
+    // out[func][opr][run]
+    HostTensorND out[2][2][2];
+
+    auto run = [&](int fid) {
+        SymbolVar o0 = make(0), o1 = make(1);
+        HostTensorND host_o0, host_o1;
+        auto func = graph->compile({make_callback_copy(o0, host_o0),
+                                    make_callback_copy(o1, host_o1)});
+        for (int i = 0; i < 2; ++i) {
+            func->execute();
+            out[fid][0][i].copy_from(host_o0);
+            out[fid][1][i].copy_from(host_o1);
         }
+    };
+    run(0);
+    run(1);
 
-        auto max_diff = [&](int off0, int off1) {
-            float diff = 0;
-            auto p0 = out[0][off0 / 2][off0 % 2].ptr<float>(),
-                 p1 = out[0][off1 / 2][off1 % 2].ptr<float>();
-            for (size_t i = 0; i < SIZE; ++ i) {
-                update_max(diff, std::abs(p0[i] - p1[i]));
-            }
-            return diff;
-        };
-
-        for (int i = 0; i < 4; ++ i) {
-            for (int j = i + 1; j < 4; ++ j)
-                ASSERT_GT(max_diff(i, j), 0.3) << i << " " << j;
+    for (int i = 0; i < 2; ++i) {
+        for (int j = 0; j < 2; ++j)
+            MGB_ASSERT_TENSOR_EQ(out[0][i][j], out[1][i][j]);
+    }
+
+    auto max_diff = [&](int off0, int off1) {
+        float diff = 0;
+        auto p0 = out[0][off0 / 2][off0 % 2].ptr<float>(),
+             p1 = out[0][off1 / 2][off1 % 2].ptr<float>();
+        for (size_t i = 0; i < size; ++i) {
+            update_max(diff, std::abs(p0[i] - p1[i]));
         }
+        return diff;
+    };
+
+    for (int i = 0; i < 4; ++i) {
+        for (int j = i + 1; j < 4; ++j)
+            ASSERT_GT(max_diff(i, j), 0.3) << i << " " << j;
     }
+}
 
-} // anonymous namespace
+}  // anonymous namespace
 
 TEST(TestOprRand, Uniform) {
     static constexpr size_t M = 128, N = 64;
     auto graph = ComputingGraph::make();
+
     SymbolVar dev_out = opr::UniformRNG::make(
-            *graph, {M, N}, {CompNode::load("xpu0")});
+            *graph, {M, N}, {CompNode::load("xpu0")}, {23, DTypeEnum::Float32});
 
     HostTensorND host_out;
     auto func = graph->compile({make_callback_copy(dev_out, host_out)});
@@ -115,9 +107,10 @@ TEST(TestOprRand, Gaussian) {
     static constexpr size_t SIZE = 123451;
     constexpr float MEAN = 1, STD = 2;
     auto graph = ComputingGraph::make();
+
     auto y = opr::GaussianRNG::make(
             SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
-            {23, MEAN, STD});
+            {23, MEAN, STD, DTypeEnum::Float32});
 
     HostTensorND host_y;
     auto func = graph->compile({make_callback_copy(y, host_y)});
@@ -130,17 +123,212 @@ TEST(TestOprRand, Gaussian) {
     ASSERT_LT(fabs(stat.std - STD), 0.1);
 }
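The three distribution tests that follow compare sampled moments against closed-form values. For reference, the identities they encode are collected in the small helper below; it is an illustrative sketch (the struct and function names are mine, not part of the test file). With 200,000 to 2,000,000 samples per parameter bucket, the standard error of the sample mean, sqrt(var/n), stays between roughly 2e-4 and 5e-3 here, so the fixed-seed assertions with a 0.01 tolerance pass with margin:

    // Closed-form moments used by the Gamma/Poisson/Beta tests (illustrative).
    struct Moments {
        double mean, var;
    };
    inline Moments gamma_moments(double a, double b) {  // shape a, scale b
        return {a * b, a * b * b};
    }
    inline Moments poisson_moments(double lam) {  // mean and variance are both lambda
        return {lam, lam};
    }
    inline Moments beta_moments(double a, double b) {
        double s = a + b;
        return {a / s, a * b / (s * s * (s + 1))};
    }

Note that the `std` locals in the tests below actually hold the variance; the assertions take `sqrt` before comparing against `BasicStat::std`.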
+TEST(TestOprRand, Gamma) {
+    std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{2000000*5}, dtype::Float32()});
+    std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{2000000*5}, dtype::Float32()});
+    auto shape_ptr = shape_host->ptr<float>();
+    auto scale_ptr = scale_host->ptr<float>();
+    for (int i = 0; i < 5; ++i) {
+        for (int j = 0; j < 2000000; ++j) {
+            shape_ptr[i * 2000000 + j] = 2 * 0.3 * i + 0.5;
+            scale_ptr[i * 2000000 + j] = i * 0.3 + 0.5;
+        }
+    }
+    auto graph = ComputingGraph::make();
+    auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
+    auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
+    auto y = opr::GammaRNG::make(shape_sym, scale_sym, {10});
+
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+
+    func->execute();
+
+    ASSERT_EQ(TensorShape({2000000*5}), host_y.shape());
+    for (int i = 0; i < 5; ++i) {
+        float a = 2 * 0.3 * i + 0.5, b = i * 0.3 + 0.5;
+        float mean = a * b;
+        float std = a * (b * b);  // variance of Gamma(a, b); sqrt taken below
+        auto stat = BasicStat::make(host_y.ptr<float>() + 2000000 * i,
+                                    2000000, mean);
+        ASSERT_LT(fabs(stat.mean - mean), 0.01);
+        ASSERT_LT(fabs(stat.std - sqrt(std)), 0.01);
+    }
+}
+
+TEST(TestOprRand, Poisson) {
+    std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()});
+    auto lam_ptr = lam_host->ptr<float>();
+    for (int i = 0; i < 5; ++i) {
+        for (int j = 0; j < 200000; ++j) {
+            lam_ptr[i * 200000 + j] = i + 1;
+        }
+    }
+    auto graph = ComputingGraph::make();
+    auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
+    auto y = opr::PoissonRNG::make(lam_sym, {10});
+
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+
+    func->execute();
+
+    ASSERT_EQ(TensorShape({200000*5}), host_y.shape());
+    for (int i = 0; i < 5; ++i) {
+        float lambda = i + 1;
+        auto stat = BasicStat::make(host_y.ptr<float>() + 200000 * i,
+                                    200000, lambda);
+        ASSERT_LT(fabs(stat.mean - lambda), 0.01);
+        ASSERT_LT(fabs(stat.std - sqrt(lambda)), 0.1);
+    }
+}
+
+TEST(TestOprRand, Beta) {
+    std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()});
+    std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{200000*5}, dtype::Float32()});
+    auto alpha_ptr = alpha_host->ptr<float>();
+    auto beta_ptr = beta_host->ptr<float>();
+    for (int i = 0; i < 5; ++i) {
+        for (int j = 0; j < 200000; ++j) {
+            alpha_ptr[i * 200000 + j] = 0.3 * i + 0.1;
+            beta_ptr[i * 200000 + j] = 2 * i * 0.3 + 0.1;
+        }
+    }
+    auto graph = ComputingGraph::make();
+    auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
+    auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
+    auto y = opr::BetaRNG::make(alpha_sym, beta_sym, {10});
+
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+
+    func->execute();
+
+    ASSERT_EQ(TensorShape({200000*5}), host_y.shape());
+    for (int i = 0; i < 5; ++i) {
+        float a = 0.3 * i + 0.1, b = 2 * i * 0.3 + 0.1;
+        float mean = a / (a + b);
+        float std = a * b / ((a + b) * (a + b) * (a + b + 1));  // variance
+        auto stat = BasicStat::make(host_y.ptr<float>() + 200000 * i,
+                                    200000, mean);
+        ASSERT_LT(fabs(stat.mean - mean), 0.01);
+        ASSERT_LT(fabs(stat.std - sqrt(std)), 0.01);
+    }
+}
+
+TEST(TestOprRand, PermutationRNG) {
+    static constexpr size_t SIZE = 123451;
+    auto graph = ComputingGraph::make();
+    auto y = opr::PermutationRNG::make(
+            SymbolVar::make_scalar(int(SIZE), *graph, {CompNode::load("xpu0")}),
+            {23, DTypeEnum::Int32});
+    HostTensorND host_y;
+    auto func = graph->compile({make_callback_copy(y, host_y)});
+
+    func->execute();
+
+    ASSERT_EQ(TensorShape({SIZE}), host_y.shape());
+    auto ptr = host_y.ptr<int32_t>();
+    std::vector<int32_t> res(SIZE);
+    int not_same = 0;
+    for (size_t i = 0; i < SIZE; ++i) {
+        // count positions holding a value larger than their index
+        if ((ptr[i] - int32_t(i)) >= 1) not_same++;
+        res[i] = ptr[i];
+    }
+    ASSERT_GT(not_same, 5000);
+    std::sort(res.begin(), res.end());
+    for (size_t i = 0; i < SIZE; ++i) {
+        // after sorting, a valid permutation is exactly 0 .. SIZE-1
+        ASSERT_EQ(res[i], int32_t(i));
+    }
+}
+
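The sort-and-compare check that closes the PermutationRNG test is the standard way to verify a permutation. As a standalone predicate it reads as follows; this is an illustrative sketch (the function name is mine), equivalent to what the test does inline:

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // True iff data[0..n) is a permutation of {0, ..., n-1}: sorting n
    // distinct values from that range must yield exactly 0, 1, ..., n-1.
    bool is_permutation_of_iota(const int32_t* data, std::size_t n) {
        std::vector<int32_t> v(data, data + n);
        std::sort(v.begin(), v.end());
        for (std::size_t i = 0; i < n; ++i)
            if (v[i] != static_cast<int32_t>(i)) return false;
        return true;
    }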
 TEST(TestOprRand, UniformReprod) {
-    check_reproducibility([](SymbolVar shp, uint64_t seed) {
+    static constexpr size_t SIZE = 123;
+    auto graph = ComputingGraph::make();
+    auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")},
+                                         "shp0", {SIZE});
+    check_reproducibility(graph, SIZE, [&shp](uint64_t seed) {
         return opr::UniformRNG::make(shp, {seed});
     });
 }
 
 TEST(TestOprRand, GaussianReprod) {
-    check_reproducibility([](SymbolVar shp, uint64_t seed) {
+    static constexpr size_t SIZE = 123;
+    auto graph = ComputingGraph::make();
+    auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")},
+                                         "shp0", {SIZE});
+    check_reproducibility(graph, SIZE, [&shp](uint64_t seed) {
         return opr::GaussianRNG::make(shp, {seed});
     });
 }
 
-// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
+TEST(TestOprRand, GammaReprod) {
+    static constexpr size_t SIZE = 123;
+    std::shared_ptr<HostTensorND> shape_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
+    std::shared_ptr<HostTensorND> scale_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
+    auto shape_ptr = shape_host->ptr<float>();
+    auto scale_ptr = scale_host->ptr<float>();
+    for (size_t i = 0; i < SIZE; ++i) {
+        shape_ptr[i] = 0.5;
+        scale_ptr[i] = 1.2;
+    }
+    auto graph = ComputingGraph::make();
+    auto shape_sym = opr::Host2DeviceCopy::make(*graph, shape_host);
+    auto scale_sym = opr::Host2DeviceCopy::make(*graph, scale_host);
+    check_reproducibility(graph, SIZE, [&shape_sym, &scale_sym](uint64_t seed) {
+        return opr::GammaRNG::make(shape_sym, scale_sym, {seed});
+    });
+}
+
+TEST(TestOprRand, PoissonReprod) {
+    static constexpr size_t SIZE = 123;
+    std::shared_ptr<HostTensorND> lam_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
+    auto lam_ptr = lam_host->ptr<float>();
+    for (size_t i = 0; i < SIZE; ++i)
+        lam_ptr[i] = 2;
+    auto graph = ComputingGraph::make();
+    auto lam_sym = opr::Host2DeviceCopy::make(*graph, lam_host);
+    check_reproducibility(graph, SIZE, [&lam_sym](uint64_t seed) {
+        return opr::PoissonRNG::make(lam_sym, {seed});
+    });
+}
+
+TEST(TestOprRand, BetaReprod) {
+    static constexpr size_t SIZE = 123;
+    std::shared_ptr<HostTensorND> alpha_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
+    std::shared_ptr<HostTensorND> beta_host(new HostTensorND{
+            CompNode::load("xpux"), TensorShape{SIZE}, dtype::Float32()});
+    auto alpha_ptr = alpha_host->ptr<float>();
+    auto beta_ptr = beta_host->ptr<float>();
+    for (size_t i = 0; i < SIZE; ++i) {
+        alpha_ptr[i] = 0.5;
+        beta_ptr[i] = 1.2;
+    }
+    auto graph = ComputingGraph::make();
+    auto alpha_sym = opr::Host2DeviceCopy::make(*graph, alpha_host);
+    auto beta_sym = opr::Host2DeviceCopy::make(*graph, beta_host);
+    check_reproducibility(graph, SIZE, [&alpha_sym, &beta_sym](uint64_t seed) {
+        return opr::BetaRNG::make(alpha_sym, beta_sym, {seed});
+    });
+}
+
+TEST(TestOprRand, PermutationReprod) {
+    static constexpr size_t SIZE = 123;
+    auto graph = ComputingGraph::make();
+    auto shp = cg::var_from_tensor_shape(*graph, {CompNode::load("xpu0")},
+                                         "shp0", {SIZE});
+    check_reproducibility(graph, SIZE, [&shp](uint64_t seed) {
+        return opr::PermutationRNG::make(shp, {seed, DTypeEnum::Float32});
+    });
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs
index 2a9c3e92..5a31493a 100644
--- a/src/serialization/impl/schema.fbs
+++ b/src/serialization/impl/schema.fbs
@@ -108,6 +108,10 @@ union OperatorParam {
     param.TQT = 74,
     param.Correlation = 75,
     param.LSQ = 76,
+    param.GammaRNG = 77,
+    param.PoissonRNG = 78,
+    param.PermutationRNG = 79,
+    param.BetaRNG = 80,
 }
 
 table Operator {
-- 
GitLab
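A closing note on the last two hunks: the MGB_SEREG_OPR registrations earlier in this patch (arity 2 for the two-input GammaRNG and BetaRNG, arity 1 otherwise) together with the OperatorParam union members appended above, with explicit ids 77 through 80 so existing serialized models keep decoding, are what let graphs containing the new ops survive a dump/load round trip. A minimal sketch, assuming the stock MegBrain serializer entry points (serialization::GraphDumper/GraphLoader with file-backed I/O; the path and function name are illustrative, not part of the patch):

    #include "megbrain/serialization/serializer.h"

    using namespace mgb;

    // Dump one graph output and load it back; loading an opr that was never
    // registered via MGB_SEREG_OPR would fail at this point.
    void roundtrip_sketch(SymbolVar y) {  // y: e.g. result of opr::GammaRNG::make
        auto dumper = serialization::GraphDumper::make(
                serialization::OutputFile::make_fs("rng_model.mgb"));
        dumper->dump({y});
        auto loader = serialization::GraphLoader::make(
                serialization::InputFile::make_fs("rng_model.mgb"));
        auto result = loader->load();  // LoadResult holding the restored outputs
    }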