From bcf86e5ce7b18050ffc469b93b69da469862bfd5 Mon Sep 17 00:00:00 2001
From: Zhou Wei <1183042833@qq.com>
Date: Fri, 24 Dec 2021 11:50:34 +0800
Subject: [PATCH] add new API/OP: paddle.poisson (#38117)

* add new API/OP:paddle.poisson

* fix comment
---
 paddle/fluid/operators/poisson_op.cc           | 132 +++++++++++++
 paddle/fluid/operators/poisson_op.cu           |  92 +++++++++
 paddle/fluid/operators/poisson_op.h            |  41 ++++
 paddle/fluid/operators/uniform_random_op.cc    |   5 +-
 paddle/scripts/paddle_build.sh                 |   2 +-
 python/paddle/__init__.py                      |   5 +-
 python/paddle/fluid/initializer.py             |   4 +-
 .../tests/unittests/test_bernoulli_op.py       |   8 +-
 .../fluid/tests/unittests/test_poisson_op.py   | 181 ++++++++++++++++++
 python/paddle/nn/initializer/dirac.py          |  10 +-
 python/paddle/tensor/__init__.py               |   1 +
 python/paddle/tensor/random.py                 |  43 +++++
 12 files changed, 506 insertions(+), 18 deletions(-)
 create mode 100644 paddle/fluid/operators/poisson_op.cc
 create mode 100644 paddle/fluid/operators/poisson_op.cu
 create mode 100644 paddle/fluid/operators/poisson_op.h
 create mode 100644 python/paddle/fluid/tests/unittests/test_poisson_op.py

diff --git a/paddle/fluid/operators/poisson_op.cc b/paddle/fluid/operators/poisson_op.cc
new file mode 100644
index 00000000000..cc4b6e5e075
--- /dev/null
+++ b/paddle/fluid/operators/poisson_op.cc
@@ -0,0 +1,132 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <string>
+
+#include "paddle/fluid/operators/poisson_op.h"
+
+namespace paddle {
+namespace operators {
+
+class PoissonOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "PoissonOp");
+    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "PoissonOp");
+
+    auto dim = ctx->GetInputDim("X");
+    ctx->SetOutputDim("Out", dim);
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
+};
+
+class PoissonOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X", "(Tensor) The input tensor of poisson op");
+    AddOutput("Out",
+              "The output tensor of poisson op, which has the same shape and "
+              "dtype as the input. Each element corresponds to an element of "
+              "the input tensor");
+    AddComment(R"DOC(
+This operator generates random values that obey a Poisson distribution.
+)DOC");
+  }
+};
+
+class PoissonOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput {
+ protected:
+  std::unordered_map<std::string, std::string> &GetInputOutputWithSameType()
+      const override {
+    static std::unordered_map<std::string, std::string> m{{"X", /*->*/ "Out"}};
+    return m;
+  }
+};
+
+template <typename T>
+class PoissonKernel<platform::CPUDeviceContext, T>
+    : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &ctx) const override {
+    const auto *x = ctx.Input<framework::Tensor>("X");
+    auto *out = ctx.Output<framework::Tensor>("Out");
+
+    const T *x_data = x->data<T>();
+    T *out_data = out->mutable_data<T>(ctx.GetPlace());
+
+    int64_t size = x->numel();
+
+    auto gen = framework::DefaultCPUGenerator();
+    auto engine = gen->GetCPUEngine();
+
+    for (int64_t i = 0; i < size; ++i) {
+      std::poisson_distribution<> dist(x_data[i]);
+      out_data[i] = static_cast<T>(dist(*engine));
+    }
+  }
+};
+
+class PoissonGradOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext *ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
+                   "Out_Grad", "PoissonGradOp");
+
+    auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out"));
+    ctx->SetOutputDim(framework::GradVarName("X"), dout_dim);
+  }
+};
+
+template <typename T>
+class PoissonGradOpMaker : public framework::SingleGradOpMaker<T> {
+ public:
+  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
+
+ protected:
+  void Apply(GradOpPtr<T> retv) const override {
+    retv->SetType("poisson_grad");
+    retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
+    retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+
+REGISTER_OPERATOR(poisson, ops::PoissonOp, ops::PoissonOpMaker,
+                  ops::PoissonOpInferVarType,
+                  ops::PoissonGradOpMaker<paddle::framework::OpDesc>,
+                  ops::PoissonGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OPERATOR(poisson_grad, ops::PoissonGradOp);
+
+REGISTER_OP_CPU_KERNEL(poisson,
+                       ops::PoissonKernel<plat::CPUDeviceContext, float>,
+                       ops::PoissonKernel<plat::CPUDeviceContext, double>);
+
+REGISTER_OP_CPU_KERNEL(
+    poisson_grad, ops::PoissonGradKernel<plat::CPUDeviceContext, float>,
+    ops::PoissonGradKernel<plat::CPUDeviceContext, double>);
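
[Note] The CPU kernel above draws each output element independently from
std::poisson_distribution, using the matching input element as the rate. A
rough NumPy sketch of that per-element loop (an illustration of the sampling
semantics, not code from this patch):

    import numpy as np

    def poisson_cpu_reference(x, seed=None):
        # Mimic the CPU kernel: out[i] ~ Poisson(x[i]) for every element.
        rng = np.random.default_rng(seed)
        out = np.empty_like(x)
        flat_in, flat_out = x.ravel(), out.ravel()
        for i in range(flat_in.size):  # the C++ kernel loops over numel()
            flat_out[i] = rng.poisson(flat_in[i])
        return out

    sample = poisson_cpu_reference(np.full((2, 3), 10.0))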
+)DOC"); + } +}; + +class PoissonOpInferVarType : public framework::PassInDtypeAndVarTypeToOutput { + protected: + std::unordered_map &GetInputOutputWithSameType() + const override { + static std::unordered_map m{{"X", /*->*/ "Out"}}; + return m; + } +}; + +template +class PoissonKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + const auto *x = ctx.Input("X"); + auto *out = ctx.Output("Out"); + + const T *x_data = x->data(); + T *out_data = out->mutable_data(ctx.GetPlace()); + + int64_t size = x->numel(); + + auto gen = framework::DefaultCPUGenerator(); + auto engine = gen->GetCPUEngine(); + + for (int64_t i = 0; i < size; ++i) { + std::poisson_distribution<> dist(x_data[i]); + out_data[i] = static_cast(dist(*engine)); + } + } +}; + +class PoissonGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext *ctx) const override { + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out_Grad", "PoissonGradOp"); + + auto dout_dim = ctx->GetInputDim(framework::GradVarName("Out")); + ctx->SetOutputDim(framework::GradVarName("X"), dout_dim); + } +}; + +template +class PoissonGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr retv) const override { + retv->SetType("poisson_grad"); + retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; + +REGISTER_OPERATOR(poisson, ops::PoissonOp, ops::PoissonOpMaker, + ops::PoissonOpInferVarType, + ops::PoissonGradOpMaker, + ops::PoissonGradOpMaker); + +REGISTER_OPERATOR(poisson_grad, ops::PoissonGradOp); + +REGISTER_OP_CPU_KERNEL(poisson, + ops::PoissonKernel, + ops::PoissonKernel); + +REGISTER_OP_CPU_KERNEL(poisson_grad, + ops::PoissonGradKernel, + ops::PoissonGradKernel); diff --git a/paddle/fluid/operators/poisson_op.cu b/paddle/fluid/operators/poisson_op.cu new file mode 100644 index 00000000000..3f18eb994e1 --- /dev/null +++ b/paddle/fluid/operators/poisson_op.cu @@ -0,0 +1,92 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 

diff --git a/paddle/fluid/operators/poisson_op.h b/paddle/fluid/operators/poisson_op.h
new file mode 100644
index 00000000000..2159637b290
--- /dev/null
+++ b/paddle/fluid/operators/poisson_op.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/generator.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/operators/math/math_function.h"
+
+namespace paddle {
+namespace operators {
+
+template <typename DeviceContext, typename T>
+class PoissonKernel;
+
+template <typename DeviceContext, typename T>
+class PoissonGradKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    dx->mutable_data<T>(ctx.GetPlace());
+    math::SetConstant<DeviceContext, T> functor;
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    functor(dev_ctx, dx, static_cast<T>(0));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
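
[Note] Poisson sampling is not differentiable, so PoissonGradKernel above
simply fills dX with zeros via math::SetConstant and ignores the upstream
gradient. Observed from Python, the contract looks like this (a sketch,
assuming the dygraph double-grad API paddle.grad):

    import paddle

    paddle.disable_static()
    x = paddle.rand([3, 3], dtype='float32') + 1.
    x.stop_gradient = False
    y = paddle.poisson(x)
    dx = paddle.grad(outputs=y.sum(), inputs=x)[0]
    assert float(dx.abs().sum()) == 0.  # gradient is all zeros by definition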

diff --git a/paddle/fluid/operators/uniform_random_op.cc b/paddle/fluid/operators/uniform_random_op.cc
index 007276b16d7..cdb4ad7c408 100644
--- a/paddle/fluid/operators/uniform_random_op.cc
+++ b/paddle/fluid/operators/uniform_random_op.cc
@@ -27,7 +27,7 @@ namespace {
 template <typename T>
 inline void UniformRealDistribution(T *data, const int64_t &size,
                                     const float &min, const float &max,
-                                    const unsigned int &seed) {
+                                    const unsigned int seed) {
   VLOG(4) << "[CPU] UniformRandomKernel<T>";
   std::uniform_real_distribution<T> dist(static_cast<T>(min),
                                          static_cast<T>(max));
@@ -41,8 +41,7 @@ inline void UniformRealDistribution(T *data, const int64_t &size,
 template <>
 inline void UniformRealDistribution(paddle::platform::bfloat16 *data,
                                     const int64_t &size, const float &min,
-                                    const float &max,
-                                    const unsigned int &seed) {
+                                    const float &max, const unsigned int seed) {
   VLOG(4) << "[CPU] UniformRandomKernel<bfloat16>";
   std::uniform_real_distribution<float> dist(min, max);
   auto engine = paddle::framework::GetCPURandomEngine(seed);
diff --git a/paddle/scripts/paddle_build.sh b/paddle/scripts/paddle_build.sh
index c58c78995e5..1c787f1f826 100755
--- a/paddle/scripts/paddle_build.sh
+++ b/paddle/scripts/paddle_build.sh
@@ -575,7 +575,7 @@ EOF
     export http_proxy=
     export https_proxy=
     set -x
-
+
     set +ex
     if [ "$1" == "cp36-cp36m" ]; then
         pip3.6 uninstall -y paddlepaddle
diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py
index 24607319f30..a473a12e240 100755
--- a/python/paddle/__init__.py
+++ b/python/paddle/__init__.py
@@ -64,8 +64,6 @@ import paddle.reader  # noqa: F401
 import paddle.static  # noqa: F401
 import paddle.vision  # noqa: F401
 
-from .tensor.random import bernoulli  # noqa: F401
-
 from .tensor.attribute import is_complex  # noqa: F401
 from .tensor.attribute import is_integer  # noqa: F401
 from .tensor.attribute import rank  # noqa: F401
@@ -248,6 +246,8 @@ from .tensor.math import angle  # noqa: F401
 from .tensor.math import fmax  # noqa: F401
 from .tensor.math import fmin  # noqa: F401
 
+from .tensor.random import bernoulli  # noqa: F401
+from .tensor.random import poisson  # noqa: F401
 from .tensor.random import multinomial  # noqa: F401
 from .tensor.random import standard_normal  # noqa: F401
 from .tensor.random import normal  # noqa: F401
@@ -488,6 +488,7 @@ __all__ = [  # noqa
     'exp',
     'expm1',
     'bernoulli',
+    'poisson',
     'sinh',
     'round',
     'DataParallel',
diff --git a/python/paddle/fluid/initializer.py b/python/paddle/fluid/initializer.py
index 27e4ef6fe28..d9c67653320 100644
--- a/python/paddle/fluid/initializer.py
+++ b/python/paddle/fluid/initializer.py
@@ -1152,12 +1152,12 @@ def calculate_gain(nonlinearity, param=None):
     Args:
         nonlinearity(str): name of nonlinearity activation function.
             If it is a linear function, which is one of
-            "linear/conv1d/conv2d/conv3d/conv1d_transpose/conv2d_transpose/conv3d_transpose" , will return 1.0
+            "linear/conv1d/conv2d/conv3d/conv1d_transpose/conv2d_transpose/conv3d_transpose" , 1.0 will be returned.
         param(bool|int|float, optional): optional parameter for somme nonlinearity function. Now, it only applies to 'leaky_relu'. Default: None,
             it will be calculated as 0.01 in the formula.
 
     Returns:
-        The recommended gain value for nonlinearity function.
+        A float value, which is the recommended gain for this nonlinearity function.
 
     Examples:
     .. code-block:: python
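
[Note] The calculate_gain docstring above is only reworded; the behaviour is
unchanged. Typical usage (illustrative; the gain values follow the documented
formulas, e.g. 5/3 for tanh and sqrt(2/(1+param^2)) for leaky_relu):

    import paddle

    gain = paddle.nn.initializer.calculate_gain('tanh')  # 5/3 = 1.666...
    gain = paddle.nn.initializer.calculate_gain('leaky_relu', param=0.2)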
diff --git a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py
index 12a29de8042..471caeb77bf 100644
--- a/python/paddle/fluid/tests/unittests/test_bernoulli_op.py
+++ b/python/paddle/fluid/tests/unittests/test_bernoulli_op.py
@@ -32,18 +32,14 @@ class TestBernoulliOp(OpTest):
     def setUp(self):
         self.op_type = "bernoulli"
         self.inputs = {"X": np.random.uniform(size=(1000, 784))}
-        self.init_attrs()
-        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
-
-    def init_attrs(self):
         self.attrs = {}
-        self.output_hist = output_hist
+        self.outputs = {"Out": np.zeros((1000, 784)).astype("float32")}
 
     def test_check_output(self):
         self.check_output_customized(self.verify_output)
 
     def verify_output(self, outs):
-        hist, prob = self.output_hist(np.array(outs[0]))
+        hist, prob = output_hist(np.array(outs[0]))
         self.assertTrue(
             np.allclose(
                 hist, prob, rtol=0, atol=0.01), "hist: " + str(hist))
diff --git a/python/paddle/fluid/tests/unittests/test_poisson_op.py b/python/paddle/fluid/tests/unittests/test_poisson_op.py
new file mode 100644
index 00000000000..854aaf88547
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_poisson_op.py
@@ -0,0 +1,181 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import paddle
+import numpy as np
+from op_test import OpTest
+import math
+
+paddle.enable_static()
+
+
+def output_hist(out, lam, a, b):
+    prob = []
+    bin = []
+    for i in range(a, b + 1):
+        prob.append((lam**i) * math.exp(-lam) / math.factorial(i))
+        bin.append(i)
+    bin.append(b + 0.1)
+
+    hist, _ = np.histogram(out, bin)
+    hist = hist.astype("float32")
+    hist = hist / float(out.size)
+    return hist, prob
+
+
+class TestPoissonOp1(OpTest):
+    def setUp(self):
+        self.op_type = "poisson"
+        self.config()
+
+        self.attrs = {}
+        self.inputs = {'X': np.full([1024, 1024], self.lam, dtype=self.dtype)}
+        self.outputs = {'Out': np.ones([1024, 1024], dtype=self.dtype)}
+
+    def config(self):
+        self.lam = 10
+        self.a = 5
+        self.b = 15
+        self.dtype = "float64"
+
+    def verify_output(self, outs):
+        hist, prob = output_hist(np.array(outs[0]), self.lam, self.a, self.b)
+        self.assertTrue(
+            np.allclose(
+                hist, prob, rtol=0.01),
+            "actual: {}, expected: {}".format(hist, prob))
+
+    def test_check_output(self):
+        self.check_output_customized(self.verify_output)
+
+    def test_check_grad_normal(self):
+        self.check_grad(
+            ['X'],
+            'Out',
+            user_defined_grads=[np.zeros(
+                [1024, 1024], dtype=self.dtype)],
+            user_defined_grad_outputs=[
+                np.random.rand(1024, 1024).astype(self.dtype)
+            ])
+
+
+class TestPoissonOp2(TestPoissonOp1):
+    def config(self):
+        self.lam = 5
+        self.a = 1
+        self.b = 9
+        self.dtype = "float32"
+
+
+class TestPoissonAPI(unittest.TestCase):
+    def test_static(self):
+        with paddle.static.program_guard(paddle.static.Program(),
+                                         paddle.static.Program()):
+            x_np = np.random.rand(10, 10)
+            x = paddle.static.data(name="x", shape=[10, 10], dtype='float64')
+            y = paddle.poisson(x)
+
+            exe = paddle.static.Executor()
+            y_np = exe.run(paddle.static.default_main_program(),
+                           feed={"x": x_np},
+                           fetch_list=[y])
+            self.assertTrue(np.min(y_np) >= 0)
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        x = paddle.randn([10, 10], dtype='float32')
+        y = paddle.poisson(x)
+        self.assertTrue(np.min(y.numpy()) >= 0)
+        paddle.enable_static()
+
+    # Test GPU Fixed random number, which is generated by 'curandStatePhilox4_32_10_t'
+    def test_fixed_random_number(self):
+        if not paddle.is_compiled_with_cuda():
+            return
+
+        paddle.disable_static()
+        paddle.set_device('gpu')
+        paddle.seed(2021)
+        x = paddle.full([32, 3, 1024, 768], 10., dtype="float32")
+        y = paddle.poisson(x)
+        y_np = y.numpy()
+
+        expect = [
+            13., 13., 11., 8., 12., 6., 9., 15., 16., 6., 13., 12., 9., 15.,
+            17., 8., 11., 16., 11., 10.
+        ]
+        self.assertTrue(np.array_equal(y_np[0, 0, 0, 0:20], expect))
+
+        expect = [
+            15., 7., 12., 8., 14., 10., 10., 11., 11., 11., 21., 6., 9., 13.,
+            13., 11., 6., 9., 12., 12.
+        ]
+        self.assertTrue(np.array_equal(y_np[8, 1, 300, 200:220], expect))
+
+        expect = [
+            10., 15., 9., 6., 4., 13., 10., 10., 13., 12., 9., 7., 10., 14., 7.,
+            10., 8., 5., 10., 14.
+        ]
+        self.assertTrue(np.array_equal(y_np[16, 1, 600, 400:420], expect))
+
+        expect = [
+            10., 9., 14., 12., 8., 9., 7., 8., 11., 10., 13., 8., 12., 9., 7.,
+            8., 11., 11., 12., 5.
+        ]
+        self.assertTrue(np.array_equal(y_np[24, 2, 900, 600:620], expect))
+
+        expect = [
+            15., 5., 11., 13., 12., 12., 13., 16., 9., 9., 7., 9., 13., 11.,
+            15., 6., 11., 9., 10., 10.
+        ]
+        self.assertTrue(np.array_equal(y_np[31, 2, 1023, 748:768], expect))
+
+        x = paddle.full([16, 1024, 1024], 5., dtype="float32")
+        y = paddle.poisson(x)
+        y_np = y.numpy()
+        expect = [
+            4., 5., 2., 9., 8., 7., 4., 7., 4., 7., 6., 3., 10., 7., 5., 7., 2.,
+            5., 5., 6.
+        ]
+        self.assertTrue(np.array_equal(y_np[0, 0, 100:120], expect))
+
+        expect = [
+            1., 4., 8., 11., 6., 5., 4., 4., 7., 4., 4., 7., 11., 6., 5., 3.,
+            4., 6., 3., 3.
+        ]
+        self.assertTrue(np.array_equal(y_np[4, 300, 300:320], expect))
+
+        expect = [
+            7., 5., 4., 6., 8., 5., 6., 7., 7., 7., 3., 10., 5., 10., 4., 5.,
+            8., 7., 5., 7.
+        ]
+        self.assertTrue(np.array_equal(y_np[8, 600, 600:620], expect))
+
+        expect = [
+            8., 6., 7., 4., 3., 0., 4., 6., 6., 4., 3., 10., 5., 1., 3., 8., 8.,
+            2., 1., 4.
+        ]
+        self.assertTrue(np.array_equal(y_np[12, 900, 900:920], expect))
+
+        expect = [
+            2., 1., 14., 3., 6., 5., 2., 2., 6., 5., 7., 4., 8., 4., 8., 4., 5.,
+            7., 1., 7.
+        ]
+        self.assertTrue(np.array_equal(y_np[15, 1023, 1000:1020], expect))
+        paddle.enable_static()
+
+
+if __name__ == "__main__":
+    unittest.main()
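
[Note] output_hist in the test above bins the samples over [a, b] and compares
the empirical frequencies against the Poisson pmf lam^k * e^(-lam) / k!, so the
test asserts distributional correctness rather than exact values. The same
check in a few lines of NumPy (illustrative, with np.random.poisson standing in
for the op output):

    import math
    import numpy as np

    lam, a, b = 10, 5, 15
    out = np.random.poisson(lam, size=(1024, 1024))
    bins = list(range(a, b + 1)) + [b + 0.1]
    hist = np.histogram(out, bins)[0] / out.size
    pmf = [lam**k * math.exp(-lam) / math.factorial(k)
           for k in range(a, b + 1)]
    assert np.allclose(hist, pmf, rtol=0.05, atol=0.01)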
diff --git a/python/paddle/nn/initializer/dirac.py b/python/paddle/nn/initializer/dirac.py
index 55765782e5a..267d5be5be2 100644
--- a/python/paddle/nn/initializer/dirac.py
+++ b/python/paddle/nn/initializer/dirac.py
@@ -27,11 +27,13 @@ class Dirac(Initializer):
     as many channels are reserved as possible.
 
     In this initialize method, elements in the middle of convolution kernels will
-    be set to 1 . The formula can be described as:
+    be set to 1. The formula can be described as follows.
 
-    $ Assuming: N=min(in\_channels, out\_channels)$
+    .. math::
 
-    $ X[d, d, shape[2]//2, shape[3]//2, ...]=1, \ d=0,1...N$
+        Assuming: N=min(in\_channels, out\_channels)
+
+        X[d, d, shape[2]//2, shape[3]//2, ...]=1, \ d=0,1...N
 
     Args:
         groups(int): 0-dimension of the Tensor will be divided by groups, each group has the same value.
@@ -46,7 +48,7 @@ class Dirac(Initializer):
 
             import paddle
 
-            #1.For kernel_size is uneven number:
+            #1. When kernel_size is an uneven number:
 
             attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Dirac())
             conv = paddle.nn.Conv1D(3, 2, 3, weight_attr=attr)
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index fcb328ba276..a5d119a8d1a 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -225,6 +225,7 @@ from .random import rand  # noqa: F401
 from .random import randint  # noqa: F401
 from .random import randint_like  # noqa: F401
 from .random import randperm  # noqa: F401
+from .random import poisson  # noqa: F401
 from .search import argmax  # noqa: F401
 from .search import argmin  # noqa: F401
 from .search import argsort  # noqa: F401
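
[Note] In the random.py hunk below, poisson dispatches straight to the C++ op
via _C_ops in dygraph mode and appends a poisson op through LayerHelper in
static mode. Static-graph usage mirrors the unit test above (illustrative):

    import numpy as np
    import paddle

    paddle.enable_static()
    with paddle.static.program_guard(paddle.static.Program()):
        x = paddle.static.data(name="x", shape=[10, 10], dtype="float64")
        y = paddle.poisson(x)
        exe = paddle.static.Executor(paddle.CPUPlace())
        y_np, = exe.run(feed={"x": np.random.rand(10, 10)}, fetch_list=[y])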
diff --git a/python/paddle/tensor/random.py b/python/paddle/tensor/random.py
index f50ad4309ea..55ca6a0d9ce 100644
--- a/python/paddle/tensor/random.py
+++ b/python/paddle/tensor/random.py
@@ -79,6 +79,49 @@ def bernoulli(x, name=None):
     return out
 
 
+def poisson(x, name=None):
+    """
+    This OP returns a tensor filled with random numbers sampled from a Poisson distribution.
+
+    .. math::
+
+        out_i \\sim Poisson(x_i)
+
+    Args:
+        x(Tensor): A tensor with the rate parameter of the Poisson distribution. The data type
+            should be float32 or float64.
+        name(str, optional): The default value is None. Normally there is no
+            need for the user to set this property. For more information, please
+            refer to :ref:`api_guide_Name`.
+    Returns:
+        Tensor: A Tensor filled with random numbers, with the same shape and dtype as ``x``.
+
+    Examples:
+        .. code-block:: python
+
+            import paddle
+            paddle.set_device('gpu')
+            paddle.seed(2021)
+
+            x = paddle.uniform([2,3], min=1.0, max=5.0)
+            out = paddle.poisson(x)
+            # [[0., 5., 1.],
+            #  [4., 3., 0.]]
+
+    """
+
+    if in_dygraph_mode():
+        return _C_ops.poisson(x)
+
+    check_variable_and_dtype(x, "x", ["float32", "float64"], "poisson")
+
+    helper = LayerHelper("poisson", **locals())
+    out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='poisson', inputs={'X': x}, outputs={'Out': out}, attrs={})
+    return out
+
+
 def multinomial(x, num_samples=1, replacement=False, name=None):
     """
     This OP returns a Tensor filled with random values sampled from a Multinomical
-- 
GitLab