Unverified commit a28eadca, authored by Liang Depeng, committed by GitHub

add bernoulli module (#5353)

* add bernoulli module

* fix doc test

* add bernoulli functor

* make changes according to review

* refine (#5415)

* fix
Co-authored-by: Bowen Chen <bob2420083992@gmail.com>
Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
Parent 41786ad7
@@ -237,3 +237,4 @@ Experimental features
.. autofunction:: oneflow.experimental.tensor_to_tensor_buffer
.. autofunction:: oneflow.experimental.Tensor.type_as
.. autofunction:: oneflow.experimental.Tensor.long
.. autofunction:: oneflow.experimental.bernoulli
@@ -59,6 +59,11 @@
signature: "Tensor ScalarMulByTensor(Tensor x, Tensor scalar)"
bind_python: True
- name: "bernoulli"
signature:
"Tensor Bernoulli(Tensor x, DataType dtype=kFloat, Generator generator=None)"
bind_python: True
- name: "broadcast_mul"
signature: "Tensor BroadcastMul(Tensor x, Tensor y)"
bind_python: True
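The "bernoulli" entry above is what generates the functional binding that the Python wrapper further down in this diff calls as flow.F.bernoulli. A minimal usage sketch, assuming an eager-enabled build of this branch; passing None for the generator falls back to the default auto generator, mirroring the functor's fallback logic below:

import numpy as np
import oneflow
import oneflow.experimental as flow

flow.enable_eager_execution()
x = flow.Tensor(np.full((2, 3), 0.5))
# dtype and generator mirror the YAML signature above; None selects the default auto generator.
y = oneflow.F.bernoulli(x, oneflow.float32, None)
print(y.numpy())  # each entry is an independent 0.0 / 1.0 draw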
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/optional.h"
#include "oneflow/core/framework/attr_map.h"
#include "oneflow/core/framework/op_builder.h"
#include "oneflow/core/framework/op_expr.h"
#include "oneflow/core/framework/op_interpreter/op_interpreter_util.h"
#include "oneflow/core/framework/tensor.h"
#include "oneflow/core/framework/tensor_tuple.h"
#include "oneflow/core/framework/op_interpreter.h"
#include "oneflow/core/framework/random_generator.h"
#include "oneflow/core/functional/function_library.h"
#include "oneflow/core/functional/impl/common.h"
#include "oneflow/core/functional/impl/unary_functor.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
namespace oneflow {
namespace one {
namespace functional {
namespace impl {
class BernoulliFunctor {
public:
BernoulliFunctor() {
bernoulli_op_ = CHECK_JUST(one::OpBuilder("bernoulli").Input("in").Output("out").Build());
}
Maybe<Tensor> operator()(const std::shared_ptr<one::Tensor>& x, const DataType& dtype,
const Optional<one::Generator>& generator) const {
MutableAttrMap bernoulli_attrs;
JUST(bernoulli_attrs.SetAttr<DataType>("dtype", dtype));
std::shared_ptr<one::Generator> gen;
if (!generator) {
gen = JUST(one::DefaultAutoGenerator());
} else {
gen = JUST(generator.value());
}
JUST(bernoulli_attrs.SetAttr<int64_t>("seed", gen->current_seed()));
const auto& bernoulli_kernel_state = std::make_shared<BernoulliKernelState>(gen);
return OpInterpUtil::Dispatch<Tensor>(
*bernoulli_op_, {x},
OpExprInterpContext{.attrs = bernoulli_attrs, .state = bernoulli_kernel_state});
}
private:
std::shared_ptr<OpExpr> bernoulli_op_;
};
} // namespace impl
ONEFLOW_FUNCTION_LIBRARY(m) { m.add_functor<impl::BernoulliFunctor>("Bernoulli"); };
} // namespace functional
} // namespace one
} // namespace oneflow
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import sys
import random
import oneflow as flow
from oneflow.python.nn.module import Module
from oneflow.python.oneflow_export import oneflow_export, experimental_api
@oneflow_export("bernoulli")
@experimental_api
def bernoulli(input, *, generator=None, out=None):
r"""This operator returns a Tensor with binaray random numbers (0 / 1) from a Bernoulli distribution.
Args:
input(Tensor) - the input tensor of probability values for the Bernoulli distribution
generator: (optional) – a pseudorandom number generator for sampling
out (Tensor, optional) – the output tensor.
Shape:
- Input: :math:`(*)`. Input can be of any shape
- Output: :math:`(*)`. Output is of the same shape as input
For example:
.. code-block:: python
>>> import numpy as np
>>> import oneflow.experimental as flow
>>> flow.enable_eager_execution()
>>> arr = np.array(
... [
... [1.0, 1.0, 1.0],
... [1.0, 1.0, 1.0],
... [1.0, 1.0, 1.0],
... ]
... )
>>> x = flow.Tensor(arr)
>>> y = flow.bernoulli(x)
>>> y
tensor([[1., 1., 1.],
[1., 1., 1.],
[1., 1., 1.]], dtype=oneflow.float32)
"""
return flow.F.bernoulli(input, flow.float32, generator)
if __name__ == "__main__":
import doctest
doctest.testmod(raise_on_error=True)
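The doctest above uses an all-ones input so its output is deterministic. With fractional probabilities the draw is random, so an illustrative check (not part of the doctest in this commit) can only assert that every sample lands in {0, 1}:

import numpy as np
import oneflow.experimental as flow

flow.enable_eager_execution()
p = flow.Tensor(np.full((3, 3), 0.5))
s = flow.bernoulli(p)
# With p = 0.5 the individual values are random; only the support {0, 1} is guaranteed.
assert set(np.unique(s.numpy())).issubset({0.0, 1.0})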
"""
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import unittest
from collections import OrderedDict
import numpy as np
import oneflow.experimental as flow
from test_util import GenArgList
def _test_bernoulli(test_case, shape):
input_arr = np.ones(shape)
x = flow.Tensor(input_arr, device=flow.device("cpu"))
y = flow.bernoulli(x)
test_case.assertTrue(np.allclose(y.numpy(), x.numpy()))
@unittest.skipIf(
not flow.unittest.env.eager_execution_enabled(),
".numpy() doesn't work in lazy mode",
)
class TestBernoulli(flow.unittest.TestCase):
def test_bernoulli(test_case):
arg_dict = OrderedDict()
arg_dict["test_functions"] = [
_test_bernoulli,
]
arg_dict["shape"] = [(2, 3), (2, 3, 4), (2, 3, 4, 5)]
for arg in GenArgList(arg_dict):
arg[0](test_case, *arg[1:])
if __name__ == "__main__":
unittest.main()
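The check above relies on the all-ones input, for which the sample is deterministically all ones. A hypothetical companion case (not part of this commit) would cover the other deterministic endpoint, an all-zeros input; it would be registered by appending it to arg_dict["test_functions"]:

def _test_bernoulli_with_zeros(test_case, shape):
    # p = 0 is the other deterministic case: every draw must come back as 0.
    x = flow.Tensor(np.zeros(shape), device=flow.device("cpu"))
    y = flow.bernoulli(x)
    test_case.assertTrue(np.allclose(y.numpy(), np.zeros(shape)))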
@@ -14,8 +14,10 @@ See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/framework/framework.h"
#include "oneflow/user/kernels/bernoulli_kernel.h"
#include "oneflow/user/kernels/op_kernel_state_wrapper.h"
#include "oneflow/user/kernels/random_seed_util.h"
#include "oneflow/user/kernels/random_mask_generator.h"
namespace oneflow {
@@ -27,13 +29,13 @@ class BernoulliKerenl final : public user_op::OpKernel {
std::shared_ptr<user_op::OpKernelState> CreateOpKernelState(
user_op::KernelInitContext* ctx) const override {
int64_t seed = GetOpKernelRandomSeed(ctx);
return std::make_shared<OpKernelStateWrapper<std::mt19937>>(seed);
const auto& generator = CHECK_JUST(one::MakeAutoGenerator());
generator->set_current_seed(ctx->Attr<int64_t>("seed"));
return std::make_shared<BernoulliKernelState>(generator);
}
private:
void Compute(user_op::KernelComputeContext* ctx, user_op::OpKernelState* state) const override {
auto* random_generator = dynamic_cast<OpKernelStateWrapper<std::mt19937>*>(state);
user_op::Tensor* in_blob = ctx->Tensor4ArgNameAndIndex("in", 0);
user_op::Tensor* out_blob = ctx->Tensor4ArgNameAndIndex("out", 0);
const T* in_dptr = in_blob->dptr<T>();
@@ -41,11 +43,18 @@ class BernoulliKerenl final : public user_op::OpKernel {
CHECK_EQ(GetDataType<T>(), in_blob->data_type());
CHECK_EQ(GetDataType<K>(), out_blob->data_type());
CHECK_EQ(in_blob->shape().elem_cnt(), out_blob->shape().elem_cnt());
auto* bernoulli_kernel_state = dynamic_cast<BernoulliKernelState*>(state);
CHECK_NOTNULL(bernoulli_kernel_state);
const auto& generator = bernoulli_kernel_state->generator();
CHECK_NOTNULL(generator);
const auto& cpu_generator = CHECK_JUST(generator->Get<one::CPUGeneratorImpl>());
for (int32_t i = 0; i < out_blob->shape().elem_cnt(); ++i) {
double prob = static_cast<double>(*(in_dptr + i));
CHECK(prob >= 0.0 && prob <= 1.0);
std::bernoulli_distribution dis(prob);
*(out_dptr + i) = dis(*random_generator->Mutable()) ? GetOneVal<K>() : GetZeroVal<K>();
*(out_dptr + i) = dis(cpu_generator->engine()) ? GetOneVal<K>() : GetZeroVal<K>();
}
}
bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
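The rewritten Compute loop draws each output element independently with the probability read from the matching input element, now pulling randomness from the shared CPU generator's engine instead of a per-kernel std::mt19937. A NumPy reference sketch of that per-element semantics, illustrative only and not OneFlow code:

import numpy as np

def bernoulli_reference(probs, rng=None):
    # Mirrors the kernel loop: validate every probability, then make one independent draw per element.
    rng = np.random.default_rng(0) if rng is None else rng
    probs = np.asarray(probs, dtype=np.float64)
    assert np.all((probs >= 0.0) & (probs <= 1.0))  # same range check as CHECK(prob >= 0.0 && prob <= 1.0)
    # A uniform draw in [0, 1) is below p with probability p, matching std::bernoulli_distribution(p).
    return (rng.random(probs.shape) < probs).astype(np.float32)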
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_
#define ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_
#include "oneflow/user/kernels/random_mask_generator.h"
#include "oneflow/core/framework/framework.h"
namespace oneflow {
class BernoulliKernelState : public user_op::OpKernelState {
public:
explicit BernoulliKernelState(const std::shared_ptr<one::Generator>& generator)
: generator_(generator) {}
const std::shared_ptr<one::Generator>& generator() const { return generator_; }
private:
std::shared_ptr<one::Generator> generator_;
};
} // namespace oneflow
#endif // ONEFLOW_USER_KERNEL_BERNOULLI_KERNEL_H_