From aa2a9b5d89cfd01780feeaca3271a142bb16f5e0 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 21 Aug 2020 23:44:33 +0800 Subject: [PATCH] add bernoulli op (#26511) * add bernoulli op * fix cuda kernel and add unit test * refine doc * fix uniform --- paddle/fluid/operators/bernoulli_op.cc | 88 +++++++++++++++++++ paddle/fluid/operators/bernoulli_op.cu | 72 +++++++++++++++ paddle/fluid/operators/bernoulli_op.h | 39 ++++++++ python/paddle/__init__.py | 1 + .../tests/unittests/test_bernoulli_op.py | 76 ++++++++++++++++ python/paddle/tensor/random.py | 54 ++++++++++++ 6 files changed, 330 insertions(+) create mode 100644 paddle/fluid/operators/bernoulli_op.cc create mode 100644 paddle/fluid/operators/bernoulli_op.cu create mode 100644 paddle/fluid/operators/bernoulli_op.h create mode 100644 python/paddle/fluid/tests/unittests/test_bernoulli_op.py diff --git a/paddle/fluid/operators/bernoulli_op.cc b/paddle/fluid/operators/bernoulli_op.cc new file mode 100644 index 0000000000..c525da5953 --- /dev/null +++ b/paddle/fluid/operators/bernoulli_op.cc @@ -0,0 +1,88 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#include "paddle/fluid/operators/bernoulli_op.h" + +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/common_infer_shape_functions.h" + +namespace paddle { +namespace operators { + +class BernoulliOpMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", + "A tensor with probabilities for generating the random binary " + "number"); + AddOutput("Out", "A Tensor filled with random binary number"); + AddComment(R"DOC( +This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution. + + Out ~ Bernoulli(X) + +)DOC"); + } +}; + +class BernoulliOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + return UnaryOpUnchangedInferShape(ctx); + } +}; + +// It seems that Eigen::Tensor::random in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class BernoulliOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + auto *in_data = x->data(); + auto *out_data = out->mutable_data(ctx.GetPlace()); + + int64_t size = x->numel(); + std::uniform_real_distribution dist(0.0, 1.0); + auto gen_ptr = framework::Generator::GetInstance(); + std::mt19937_64 &gen_engine = gen_ptr->GetCPUEngine(); + + for (int64_t i = 0; i < size; ++i) { + out_data[i] = BernoulliFunctor(in_data[i], dist(gen_engine)); + } + } +}; // namespace operators + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR( + bernoulli, ops::BernoulliOp, ops::BernoulliOpMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(bernoulli, + ops::BernoulliOpKernel, + ops::BernoulliOpKernel); diff --git a/paddle/fluid/operators/bernoulli_op.cu b/paddle/fluid/operators/bernoulli_op.cu new file mode 100644 index 0000000000..d0837071d4 --- /dev/null +++ b/paddle/fluid/operators/bernoulli_op.cu @@ -0,0 +1,72 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/generator.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/operators/bernoulli_op.h" +#include "paddle/fluid/platform/transform.h" + +namespace paddle { +namespace operators { +// it can be consistent with cpu when CUDAGenerator is provided. +template +struct BernoulliCudaFunctor { + unsigned int seed_; + __host__ __device__ BernoulliCudaFunctor(int seed) : seed_(seed) {} + + __host__ __device__ T operator()(const unsigned int n, const T p) const { + thrust::minstd_rand rng; + rng.seed(seed_); + thrust::uniform_real_distribution dist(0.0, 1.0); + rng.discard(n); + return static_cast(dist(rng) < p); + } +}; + +template +class BernoulliOpKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + std::random_device rd; + auto seed = rd(); + const auto x = ctx.Input("X"); + auto out = ctx.Output("Out"); + auto* in_data = x->data(); + auto* out_data = out->mutable_data(ctx.GetPlace()); + + int64_t size = x->numel(); + thrust::counting_iterator index_sequence_begin(0); + platform::Transform trans; + auto* context = + static_cast(&ctx.device_context()); + trans(*context, index_sequence_begin, index_sequence_begin + size, in_data, + out_data, BernoulliCudaFunctor(seed)); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OP_CUDA_KERNEL( + bernoulli, ops::BernoulliOpKernel, + ops::BernoulliOpKernel); diff --git a/paddle/fluid/operators/bernoulli_op.h b/paddle/fluid/operators/bernoulli_op.h new file mode 100644 index 0000000000..06a83ada17 --- /dev/null +++ b/paddle/fluid/operators/bernoulli_op.h @@ -0,0 +1,39 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/fluid/platform/enforce.h" +#include "paddle/fluid/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +/** + * Samples a bernoulli distribution given a probability input + */ + +template +inline HOSTDEVICE T BernoulliFunctor(T p, T rand) { + PADDLE_ENFORCE_LE(p, 1, platform::errors::OutOfRange( + "The probability should be <= 1, but got %f", p)); + PADDLE_ENFORCE_GE(p, 0, platform::errors::OutOfRange( + "The probability should be >= 1, but got %f", p)); + return static_cast(rand < p); +} + +template +class BernoulliOpKernel; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index aeabe30b59..8153589e0c 100644 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -53,6 +53,7 @@ import paddle.incubate.complex as complex # TODO: define alias in tensor and framework directory from .tensor.random import randperm +from .tensor.random import bernoulli from .tensor.attribute import rank #DEFINE_ALIAS from .tensor.attribute import shape #DEFINE_ALIAS diff --git a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py new file mode 100644 index 0000000000..12a29de804 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py @@ -0,0 +1,76 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +from op_test import OpTest +import numpy as np + + +def output_hist(out): + hist, _ = np.histogram(out, bins=2) + hist = hist.astype("float32") + hist /= float(out.size) + prob = 0.5 * np.ones((2)) + return hist, prob + + +class TestBernoulliOp(OpTest): + def setUp(self): + self.op_type = "bernoulli" + self.inputs = {"X": np.random.uniform(size=(1000, 784))} + self.init_attrs() + self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")} + + def init_attrs(self): + self.attrs = {} + self.output_hist = output_hist + + def test_check_output(self): + self.check_output_customized(self.verify_output) + + def verify_output(self, outs): + hist, prob = self.output_hist(np.array(outs[0])) + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + + +class TestBernoulliApi(unittest.TestCase): + def test_dygraph(self): + paddle.disable_static() + x = paddle.rand([1024, 1024]) + out = paddle.bernoulli(x) + paddle.enable_static() + hist, prob = output_hist(out.numpy()) + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + + def test_static(self): + x = paddle.rand([1024, 1024]) + out = paddle.bernoulli(x) + exe = paddle.static.Executor(paddle.CPUPlace()) + out = exe.run(paddle.static.default_main_program(), + fetch_list=[out.name]) + hist, prob = output_hist(out[0]) + self.assertTrue( + np.allclose( + hist, prob, rtol=0, atol=0.01), "hist: " + str(hist)) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py index 353a22f87f..6b8986004d 100644 --- a/python/paddle/tensor/random.py +++ b/python/paddle/tensor/random.py @@ -27,6 +27,7 @@ from ..fluid.layers.tensor import fill_constant from ..fluid.io import shuffle #DEFINE_ALIAS __all__ = [ + 'bernoulli', # 'gaussin', 'uniform', 'shuffle', @@ -37,6 +38,59 @@ __all__ = [ ] +def bernoulli(x, name=None): + """ + + This OP returns a Tensor filled with random binary(0 or 1) number from a Bernoulli distribution. + The input ``x`` is a tensor with probabilities for generating the random binary number. + Each element in ``x`` should be in [0, 1], and the out is generated by: + + .. math:: + + out_i ~ Bernoulli (x_i) + + Args: + x(Tensor): A tensor with probabilities for generating the random binary number. The data type + should be float32, float64. + name(str, optional): The default value is None. Normally there is no + need for user to set this property. For more information, please + refer to :ref:`api_guide_Name`. + Returns: + Tensor: A Tensor filled with random binary number with the same shape and dtype as ``x``. + + Examples: + .. code-block:: python + + import paddle + import numpy as np + + paddle.disable_static() + + x = paddle.rand([2, 3]) + print(x.numpy()) + # [[0.11272584 0.3890902 0.7730957 ] + # [0.10351662 0.8510418 0.63806665]] + + out = paddle.bernoulli(x) + print(out.numpy()) + # [[0. 0. 1.] + # [0. 0. 1.]] + + """ + + if in_dygraph_mode(): + return core.ops.bernoulli(x) + + check_variable_and_dtype(x, "x", ["float32", "float64"], "bernoulli") + + helper = LayerHelper("randint", **locals()) + out = helper.create_variable_for_type_inference( + dtype=x.dtype) # maybe set out to int32 ? + helper.append_op( + type='bernoulli', inputs={"X": x}, outputs={'Out': out}, attrs={}) + return out + + def uniform(shape, dtype='float32', min=-1.0, max=1.0, seed=0, name=None): """ This OP returns a Tensor filled with random values sampled from a uniform -- GitLab