From a28eadcab84b7291bc18099538cae11794c8d872 Mon Sep 17 00:00:00 2001 From: Liang Depeng Date: Wed, 7 Jul 2021 19:41:54 +0800 Subject: [PATCH] add bernoulli module (#5353) * add bernoulli module * fix doc test * add bernoulli functor * make changes according to review * refine (#5415) * fix Co-authored-by: Bowen Chen Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com> --- docs/source/experimental.rst | 1 + oneflow/core/functional/functional_api.yaml | 5 ++ .../core/functional/impl/random_functor.cpp | 72 +++++++++++++++++++ oneflow/python/nn/modules/random_ops.py | 67 +++++++++++++++++ oneflow/python/test/modules/test_bernoulli.py | 48 +++++++++++++ oneflow/user/kernels/bernoulli_kernel.cpp | 17 +++-- oneflow/user/kernels/bernoulli_kernel.h | 37 ++++++++++ 7 files changed, 243 insertions(+), 4 deletions(-) create mode 100644 oneflow/core/functional/impl/random_functor.cpp create mode 100644 oneflow/python/nn/modules/random_ops.py create mode 100644 oneflow/python/test/modules/test_bernoulli.py create mode 100644 oneflow/user/kernels/bernoulli_kernel.h diff --git a/docs/source/experimental.rst b/docs/source/experimental.rst index 3cabd4803c..5bfb364007 100644 --- a/docs/source/experimental.rst +++ b/docs/source/experimental.rst @@ -237,3 +237,4 @@ Experimental features .. autofunction:: oneflow.experimental.tensor_to_tensor_buffer .. autofunction:: oneflow.experimental.Tensor.type_as .. autofunction:: oneflow.experimental.Tensor.long +.. autofunction:: oneflow.experimental.bernoulli diff --git a/oneflow/core/functional/functional_api.yaml b/oneflow/core/functional/functional_api.yaml index 74eb602a24..273b2635b2 100644 --- a/oneflow/core/functional/functional_api.yaml +++ b/oneflow/core/functional/functional_api.yaml @@ -59,6 +59,11 @@ signature: "Tensor ScalarMulByTensor(Tensor x, Tensor scalar)" bind_python: True +- name: "bernoulli" + signature: + "Tensor Bernoulli(Tensor x, DataType dtype=kFloat, Generator generator=None)" + bind_python: True + - name: "broadcast_mul" signature: "Tensor BroadcastMul(Tensor x, Tensor y)" bind_python: True diff --git a/oneflow/core/functional/impl/random_functor.cpp b/oneflow/core/functional/impl/random_functor.cpp new file mode 100644 index 0000000000..f630bd28c6 --- /dev/null +++ b/oneflow/core/functional/impl/random_functor.cpp @@ -0,0 +1,72 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#include "oneflow/core/common/optional.h" +#include "oneflow/core/framework/attr_map.h" +#include "oneflow/core/framework/op_builder.h" +#include "oneflow/core/framework/op_expr.h" +#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h" +#include "oneflow/core/framework/tensor.h" +#include "oneflow/core/framework/tensor_tuple.h" +#include "oneflow/core/framework/op_interpreter.h" +#include "oneflow/core/framework/random_generator.h" +#include "oneflow/core/functional/function_library.h" +#include "oneflow/core/functional/impl/common.h" +#include "oneflow/core/functional/impl/unary_functor.h" +#include "oneflow/user/kernels/bernoulli_kernel.h" + +namespace oneflow { +namespace one { +namespace functional { + +namespace impl { + +class BernoulliFunctor { + public: + BernoulliFunctor() { + bernoulli_op_ = CHECK_JUST(one::OpBuilder("bernoulli").Input("in").Output("out").Build()); + } + Maybe operator()(const std::shared_ptr& x, const DataType& dtype, + const Optional& generator) const { + MutableAttrMap bernoulli_attrs; + JUST(bernoulli_attrs.SetAttr("dtype", dtype)); + + std::shared_ptr gen; + if (!generator) { + gen = JUST(one::DefaultAutoGenerator()); + } else { + gen = JUST(generator.value()); + } + + JUST(bernoulli_attrs.SetAttr("seed", gen->current_seed())); + + const auto& bernoulli_kernel_state = std::make_shared(gen); + + return OpInterpUtil::Dispatch( + *bernoulli_op_, {x}, + OpExprInterpContext{.attrs = bernoulli_attrs, .state = bernoulli_kernel_state}); + } + + private: + std::shared_ptr bernoulli_op_; +}; + +} // namespace impl + +ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor("Bernoulli"); }; + +} // namespace functional +} // namespace one +} // namespace oneflow diff --git a/oneflow/python/nn/modules/random_ops.py b/oneflow/python/nn/modules/random_ops.py new file mode 100644 index 0000000000..884b235cd7 --- /dev/null +++ b/oneflow/python/nn/modules/random_ops.py @@ -0,0 +1,67 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import sys +import random +import oneflow as flow +from oneflow.python.nn.module import Module +from oneflow.python.oneflow_export import oneflow_export, experimental_api + + +@oneflow_export("bernoulli") +@experimental_api +def bernoulli(input, *, generator=None, out=None): + r"""This operator returns a Tensor with binaray random numbers (0 / 1) from a Bernoulli distribution. + + Args: + input(Tensor) - the input tensor of probability values for the Bernoulli distribution + generator: (optional) – a pseudorandom number generator for sampling + out (Tensor, optional) – the output tensor. + + Shape: + - Input: :math:`(*)`. Input can be of any shape + - Output: :math:`(*)`. Output is of the same shape as input + + For example: + + .. code-block:: python + + >>> import numpy as np + >>> import oneflow.experimental as flow + >>> flow.enable_eager_execution() + + >>> arr = np.array( + ... [ + ... [1.0, 1.0, 1.0], + ... [1.0, 1.0, 1.0], + ... [1.0, 1.0, 1.0], + ... ] + ... ) + >>> x = flow.Tensor(arr) + >>> y = flow.bernoulli(x) + >>> y + tensor([[1., 1., 1.], + [1., 1., 1.], + [1., 1., 1.]], dtype=oneflow.float32) + + + """ + return flow.F.bernoulli(input, flow.float32, generator) + + +if __name__ == "__main__": + import doctest + + doctest.testmod(raise_on_error=True) diff --git a/oneflow/python/test/modules/test_bernoulli.py b/oneflow/python/test/modules/test_bernoulli.py new file mode 100644 index 0000000000..8b80cfa50c --- /dev/null +++ b/oneflow/python/test/modules/test_bernoulli.py @@ -0,0 +1,48 @@ +""" +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" +import unittest +from collections import OrderedDict + +import numpy as np + +import oneflow.experimental as flow +from test_util import GenArgList + + +def _test_bernoulli(test_case, shape): + input_arr = np.ones(shape) + x = flow.Tensor(input_arr, device=flow.device("cpu")) + y = flow.bernoulli(x) + test_case.assertTrue(np.allclose(y.numpy(), x.numpy())) + + +@unittest.skipIf( + not flow.unittest.env.eager_execution_enabled(), + ".numpy() doesn't work in lazy mode", +) +class TestBernoulli(flow.unittest.TestCase): + def test_bernoulli(test_case): + arg_dict = OrderedDict() + arg_dict["test_functions"] = [ + _test_bernoulli, + ] + arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)] + for arg in GenArgList(arg_dict): + arg[0](test_case, *arg[1:]) + + +if __name__ == "__main__": + unittest.main() diff --git a/oneflow/user/kernels/bernoulli_kernel.cpp b/oneflow/user/kernels/bernoulli_kernel.cpp index 222aa058f9..00b98c9372 100644 --- a/oneflow/user/kernels/bernoulli_kernel.cpp +++ b/oneflow/user/kernels/bernoulli_kernel.cpp @@ -14,8 +14,10 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "oneflow/core/framework/framework.h" +#include "oneflow/user/kernels/bernoulli_kernel.h" #include "oneflow/user/kernels/op_kernel_state_wrapper.h" #include "oneflow/user/kernels/random_seed_util.h" +#include "oneflow/user/kernels/random_mask_generator.h" namespace oneflow { @@ -27,13 +29,13 @@ class BernoulliKerenl final : public user_op::OpKernel { std::shared_ptr CreateOpKernelState( user_op::KernelInitContext* ctx) const override { - int64_t seed = GetOpKernelRandomSeed(ctx); - return std::make_shared>(seed); + const auto& generator = CHECK_JUST(one::MakeAutoGenerator()); + generator->set_current_seed(ctx->Attr("seed")); + return std::make_shared(generator); } private: void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override { - auto* random_generator = dynamic_cast*>(state); user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0); user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0); const T* in_dptr = in_blob->dptr(); @@ -41,11 +43,18 @@ class BernoulliKerenl final : public user_op::OpKernel { CHECK_EQ(GetDataType(), in_blob->data_type()); CHECK_EQ(GetDataType(), out_blob->data_type()); CHECK_EQ(in_blob->shape().elem_cnt(), out_blob->shape().elem_cnt()); + + auto* bernoulli_kernel_state = dynamic_cast(state); + CHECK_NOTNULL(bernoulli_kernel_state); + const auto& generator = bernoulli_kernel_state->generator(); + CHECK_NOTNULL(generator); + const auto& cpu_generator = CHECK_JUST(generator->Get()); + for (int32_t i = 0; i < out_blob->shape().elem_cnt(); ++i) { double prob = static_cast(*(in_dptr + i)); CHECK(prob >= 0.0 && prob <= 1.0); std::bernoulli_distribution dis(prob); - *(out_dptr + i) = dis(*random_generator->Mutable()) ? GetOneVal() : GetZeroVal(); + *(out_dptr + i) = dis(cpu_generator->engine()) ? GetOneVal() : GetZeroVal(); } } bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } diff --git a/oneflow/user/kernels/bernoulli_kernel.h b/oneflow/user/kernels/bernoulli_kernel.h new file mode 100644 index 0000000000..657f298c23 --- /dev/null +++ b/oneflow/user/kernels/bernoulli_kernel.h @@ -0,0 +1,37 @@ +/* +Copyright 2020 The OneFlow Authors. All rights reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ +#ifndef ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_ +#define ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_ + +#include "oneflow/user/kernels/random_mask_generator.h" +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +class BernoulliKernelState : public user_op::OpKernelState { + public: + explicit BernoulliKernelState(const std::shared_ptr& generator) + : generator_(generator) {} + + const std::shared_ptr& generator() const { return generator_; } + + private: + std::shared_ptr generator_; +}; + +} // namespace oneflow + +#endif // ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_ -- GitLab