Unverified commit be29b8ee authored by JYChen, committed by GitHub

add uniform_ op and UT (#33934)

Parent 5a72cf43
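In short, the patch exposes an in-place uniform sampler on Tensor. A minimal
usage sketch (assuming a Paddle build that includes this commit; the API and
attribute names below are taken from the diff):

    import paddle

    x = paddle.ones([2, 3])
    # Fill x in place with samples from [min, max); seed=0 defers to the
    # global default generator, which can be set via paddle.seed.
    x.uniform_(min=-1.0, max=1.0, seed=0)
    print(x)  # random values in [-1.0, 1.0)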
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class UniformRandomInplaceOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
This operator fills the input tensor in place with random values sampled
from a uniform distribution. The results lie in the range [min, max).
)DOC");
AddInput("X", "The input tensor.");
AddOutput("Out", "The output tensor of uniform random op");
AddAttr<float>("min", "Minimum value of uniform random. [default -1.0].")
.SetDefault(-1.0f);
AddAttr<float>("max", "Maximun value of uniform random. [default 1.0].")
.SetDefault(1.0f);
AddAttr<int>("seed",
"Random seed used for generating samples. "
"If seed is 0, it will use the seed of the global default "
"generator (which can be set by paddle.seed). "
"Note that if seed is not 0, this operator will always "
"generate the same random numbers every time. [default 0].")
.SetDefault(0);
AddAttr<int>("diag_num",
"The number of diag elements. Note that if "
"diag_num is 0, it means without diag init.[default 0].")
.SetDefault(0);
AddAttr<int>("diag_step", "The step between two diag element.[default 0].")
.SetDefault(0);
AddAttr<float>("diag_val", "The value of diag element. [default 1.0].")
.SetDefault(1.0f);
}
};
class UniformRandomInplaceOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "UniformRandomInplaceOp");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
"UniformRandomInplaceOp");
PADDLE_ENFORCE_LT(
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max"),
platform::errors::InvalidArgument(
"The uniform_random's min must less then max. But received min = "
"%f great than or equal max = %f.",
ctx->Attrs().Get<float>("min"), ctx->Attrs().Get<float>("max")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_num"), 0,
platform::errors::InvalidArgument(
"The uniform_random's diag_num must greater than or "
"equal 0. But recevied diag_num (%d) < 0.",
ctx->Attrs().Get<int>("diag_num")));
PADDLE_ENFORCE_GE(ctx->Attrs().Get<int>("diag_step"), 0,
platform::errors::InvalidArgument(
"The uniform_random's diag_step must greater than or "
"equal 0. But recevied diag_step (%d) < 0.",
ctx->Attrs().Get<int>("diag_step")));
auto xdim = ctx->GetInputDim("X");
ctx->SetOutputDim("Out", xdim);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
}
};
template <typename T>
class CPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
T *data = tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel();
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
auto engine = paddle::framework::GetCPURandomEngine(
static_cast<unsigned int>(ctx.Attr<int>("seed")));
for (int64_t i = 0; i < size; ++i) {
data[i] = dist(*engine);
}
}
};
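// uniform_ runs in place: the output shares the input variable, so there is
// no variable type to infer and the body below is intentionally empty.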
class UniformRandomInplaceOpVarTypeInference
: public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {}
};
class UniformRandomInplaceGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out_Grad", "UniformRandomInplaceGradOp");
auto x_dims = ctx->GetInputDim(framework::GradVarName("Out"));
auto x_grad_name = framework::GradVarName("X");
if (ctx->HasOutput(x_grad_name)) {
ctx->SetOutputDim(x_grad_name, x_dims);
}
}
};
template <typename T>
class UniformRandomInplaceGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> retv) const override {
retv->SetType(this->ForwardOpType() + "_grad");
retv->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
retv->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
retv->SetAttrMap(this->Attrs());
}
};
template <typename T>
class CPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
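    // The forward op overwrites X with random samples, so it has no
    // meaningful derivative w.r.t. X; the gradient is defined as all zeros.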
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (dx) {
auto *data = dx->mutable_data<T>(ctx.GetPlace());
std::fill(data, data + dx->numel(), T(0));
}
}
};
} // namespace operators
} // namespace paddle
DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(UniformRandomInplaceGradInplaceInferer,
{paddle::framework::GradVarName("Out"),
paddle::framework::GradVarName("X")});
REGISTER_OPERATOR(uniform_random_inplace,
paddle::operators::UniformRandomInplaceOp,
paddle::operators::UniformRandomInplaceOpMaker,
paddle::operators::UniformRandomInplaceGradOpMaker<
paddle::framework::OpDesc>,
paddle::operators::UniformRandomInplaceGradOpMaker<
paddle::imperative::OpBase>,
paddle::operators::UniformRandomInplaceOpVarTypeInference,
UniformRandomInplaceInferer);
REGISTER_OPERATOR(uniform_random_inplace_grad,
paddle::operators::UniformRandomInplaceGradOp,
UniformRandomInplaceGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
uniform_random_inplace,
paddle::operators::CPUUniformRandomInplaceKernel<float>,
paddle::operators::CPUUniformRandomInplaceKernel<double>);
REGISTER_OP_CPU_KERNEL(
uniform_random_inplace_grad,
paddle::operators::CPUUniformRandomInplaceGradKernel<float>,
paddle::operators::CPUUniformRandomInplaceGradKernel<double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
template <typename T>
struct UniformGenerator {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
__host__ __device__ UniformGenerator(T min, T max, int seed, int diag_num,
int diag_step, T diag_val)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val) {}
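  // Each index n re-seeds a cheap LCG and discards n draws, so element n gets
  // the n-th value of one deterministic sequence; every (diag_step_ + 1)-th
  // element (at most diag_num_ of them) is then overwritten with diag_val_.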
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
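// Same generator, but the sequence starts offset_ draws in; used when the
// global CUDA generator hands out (seed, offset) pairs so that successive
// kernel launches consume disjoint windows of the random sequence.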
template <typename T>
struct UniformGeneratorOffset {
T min_, max_;
unsigned int seed_;
T diag_val_;
unsigned int diag_num_;
unsigned int diag_step_;
int offset_;
__host__ __device__ UniformGeneratorOffset(T min, T max, int seed,
int diag_num, int diag_step,
T diag_val, int offset)
: min_(min),
max_(max),
seed_(seed),
diag_num_(diag_num),
diag_step_(diag_step),
diag_val_(diag_val),
offset_(offset) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
thrust::uniform_real_distribution<T> dist(min_, max_);
rng.discard(n + offset_);
T out = dist(rng);
unsigned int remainder = n % (diag_step_ + 1);
if (remainder == 0 && diag_num_ > n / (diag_step_ + 1)) {
out = diag_val_;
}
return out;
}
};
template <typename T>
__global__ void fill_value(int64_t size, T* data, float value) {
for (int idx = threadIdx.x; idx < size; idx += blockDim.x) {
data[idx] = static_cast<T>(value);
}
}
// It seems that Eigen::Tensor::random on the GPU can SEGFAULT.
// Use std::random and thrust::random (thrust is a standard library shipped
// with CUDA) to implement uniform random sampling, as in uniform_random_op.cu.
template <typename T>
class GPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto* tensor = out_var->GetMutable<framework::LoDTensor>();
T* data = tensor->mutable_data<T>(ctx.GetPlace());
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
bool seed_flag = false;
if (seed == 0) {
std::random_device rd;
seed = rd();
seed_flag = true;
}
T min = static_cast<T>(ctx.Attr<float>("min"));
T max = static_cast<T>(ctx.Attr<float>("max"));
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
T diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = tensor->numel();
int device_id =
BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()).GetDeviceId();
auto gen_cuda = framework::GetDefaultCUDAGenerator(device_id);
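    // If the global CUDA generator was initialized from Python (paddle.seed)
    // and no per-op seed was given, take its (seed, offset) pair and advance
    // the offset by one batch so repeated calls stay reproducible yet distinct.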
if (gen_cuda->GetIsInitPy() && seed_flag) {
auto seed_offset = gen_cuda->IncrementOffset(1);
int gen_offset = size * seed_offset.second;
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGeneratorOffset<T>(min, max, seed_offset.first, diag_num,
diag_step, diag_val, gen_offset));
} else {
thrust::transform(
index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformGenerator<T>(min, max, seed, diag_num, diag_step, diag_val));
}
}
};
template <typename T>
class GPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
#ifdef __HIPCC__
const int64_t kMaxBlockDim = 256;
#else
const int64_t kMaxBlockDim = 512;
#endif
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* data = dx->mutable_data<T>(ctx.GetPlace());
auto size = dx->numel();
int64_t kBlockDim = std::min(size, kMaxBlockDim);
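    // Zero the gradient with a single block; fill_value strides by blockDim.x
    // so one block of at most kMaxBlockDim threads covers the whole buffer.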
fill_value<T><<<1, kBlockDim, 0>>>(size, data, static_cast<float>(0));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_CUDA_KERNEL(
uniform_random_inplace,
paddle::operators::GPUUniformRandomInplaceKernel<float>,
paddle::operators::GPUUniformRandomInplaceKernel<double>);
REGISTER_OP_CUDA_KERNEL(
uniform_random_inplace_grad,
paddle::operators::GPUUniformRandomInplaceGradKernel<float>,
paddle::operators::GPUUniformRandomInplaceGradKernel<double>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
namespace paddle {
namespace operators {
template <typename T>
class XPUUniformRandomInplaceKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto out_var = ctx.OutputVar("Out");
auto *tensor = out_var->GetMutable<framework::LoDTensor>();
T *data = tensor->mutable_data<T>(ctx.GetPlace());
int64_t size = tensor->numel();
std::unique_ptr<T[]> data_cpu(new T[size]);
std::uniform_real_distribution<T> dist(
static_cast<T>(ctx.Attr<float>("min")),
static_cast<T>(ctx.Attr<float>("max")));
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
auto engine = framework::GetCPURandomEngine(seed);
for (int64_t i = 0; i < size; ++i) {
data_cpu[i] = dist(*engine);
}
unsigned int diag_num =
static_cast<unsigned int>(ctx.Attr<int>("diag_num"));
unsigned int diag_step =
static_cast<unsigned int>(ctx.Attr<int>("diag_step"));
auto diag_val = static_cast<T>(ctx.Attr<float>("diag_val"));
if (diag_num > 0) {
PADDLE_ENFORCE_GT(
size, (diag_num - 1) * (diag_step + 1),
platform::errors::InvalidArgument(
"ShapeInvalid: the diagonal's elements is equal (num-1) "
"* (step-1) with num %d, step %d,"
"It should be smaller than %d, but received %d",
diag_num, diag_step, (diag_num - 1) * (diag_step + 1), size));
for (int64_t i = 0; i < diag_num; ++i) {
int64_t pos = i * diag_step + i;
data_cpu[pos] = diag_val;
}
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), data,
platform::CPUPlace(), reinterpret_cast<void *>(data_cpu.get()),
size * sizeof(T));
}
};
template <typename T>
class XPUUniformRandomInplaceGradKernel : public framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext &ctx) const override {
auto *dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
if (dx) {
T *data = dx->mutable_data<T>(ctx.GetPlace());
int64_t size = dx->numel();
std::unique_ptr<T[]> data_cpu(new T[size]);
for (int64_t i = 0; i < size; ++i) {
data_cpu[i] = T(0);
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, ctx.GetPlace()), data,
platform::CPUPlace(),
reinterpret_cast<void *>(data_cpu.get()), size * sizeof(T));
}
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_XPU_KERNEL(uniform_random_inplace,
paddle::operators::XPUUniformRandomInplaceKernel<float>);
REGISTER_OP_XPU_KERNEL(
uniform_random_inplace_grad,
paddle::operators::XPUUniformRandomInplaceGradKernel<float>);
#endif // PADDLE_WITH_XPU
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.fluid as fluid
import numpy as np
class TestUniformRandomInplaceOpDtype(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
def test_uniform_random_inplace_op_dtype(self):
def test_fp32():
tensor_fp32 = paddle.ones(self.shape, dtype=paddle.float32)
tensor_fp32.uniform_()
self.assertEqual(tensor_fp32.dtype, paddle.float32)
def test_fp64():
tensor_fp64 = paddle.ones(self.shape, paddle.float64)
tensor_fp64.uniform_()
self.assertEqual(tensor_fp64.dtype, paddle.float64)
places = ['cpu']
if fluid.core.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
test_fp32()
test_fp64()
class TestUniformRandomInplaceOpIsInplace(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
def test_uniform_random_inplace_op_is_inplace(self):
tensor_a = paddle.ones(self.shape)
tensor_b = tensor_a.uniform_()
self.assertTrue(tensor_a is tensor_b)
class TestUniformRandomInplaceOpSeedIsZero(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
self.seed = 0
def test_uniform_random_inplace_op_seed_is_zero(self):
tensor = paddle.ones(self.shape)
tensor.uniform_(seed=self.seed)
tensor_data_first = tensor.numpy()
tensor.uniform_(seed=self.seed)
tensor_data_second = tensor.numpy()
self.assertFalse((tensor_data_first == tensor_data_second).all())
class TestUniformRandomInplaceOpSeedIsNotZero(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
self.seed = 10
def test_uniform_random_inplace_op_seed_is_not_zero(self):
tensor = paddle.ones(self.shape)
tensor.uniform_(seed=self.seed)
tensor_data_first = tensor.numpy()
tensor.uniform_(seed=self.seed)
tensor_data_second = tensor.numpy()
self.assertTrue((tensor_data_first == tensor_data_second).all())
class TestUniformRandomInplaceOpWithinRange(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
self.min = -2
self.max = 1
self.seed = 10
def test_uniform_random_inplace_op_within_range(self):
tensor = paddle.ones(self.shape)
tensor.uniform_(min=self.min, max=self.max, seed=self.seed)
tensor_data = tensor.numpy()
self.assertTrue((tensor_data > self.min).all() and
(tensor_data < self.max).all())
class TestUniformRandomInplaceOpShape(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
def test_uniform_random_inplace_op_shape(self):
tensor = paddle.ones(self.shape)
tensor.uniform_()
tensor_shape_np = np.array(tensor.shape)
origin_shape = np.array(self.shape)
self.assertTrue((tensor_shape_np == origin_shape).all())
class TestUniformRandomInplaceOpDistribution(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
self.min = -3
self.max = 5
self.seed = 10
self.bins = 100
def test_uniform_random_inplace_op_distribution(self):
tensor = paddle.ones(self.shape)
tensor.uniform_(self.min, self.max, self.seed)
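        # Bin one row (shape[1] samples) into `bins` buckets; a uniform draw
        # should put roughly 1/bins of the samples in each bucket.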
hist, _ = np.histogram(tensor.numpy()[0], bins=self.bins)
prob = hist / float(self.shape[1])
prob_expect = np.ones((self.bins, )) / float(self.bins)
self.assertTrue(np.allclose(prob, prob_expect, rtol=0, atol=1e-2))
class TestUniformRandomInplaceOpError(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
def test_uniform_random_inplace_op_error(self):
def test_attr_error():
tensor = paddle.ones(self.shape)
tensor.uniform_(shape=self.shape, min=-2, max=2)
self.assertRaises(TypeError, test_attr_error)
class TestUniformRandomInplaceOpEmptyTensor(unittest.TestCase):
def test_uniform_random_inplace_op_empty_tensor(self):
places = ['cpu']
if fluid.core.is_compiled_with_cuda():
places.append('gpu')
test_shapes = [(200, 0), (0, 200)]
for place in places:
paddle.set_device(place)
for test_shape in test_shapes:
tensor = paddle.empty(shape=test_shape)
tensor.uniform_()
tensor_shape_np = np.array(tensor.shape)
origin_shape = np.array(test_shape)
self.assertTrue((tensor_shape_np == origin_shape).all())
class TestUniformRandomInplaceGrad(unittest.TestCase):
def setUp(self):
self.shape = (1000, 784)
def test_uniform_random_inplace_grad(self):
def test_grad():
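            # uniform_ is recorded as an op in the graph; its grad kernel
            # returns zeros, so the gradient reaching tensor_b (and anything
            # upstream of it) is all zero.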
tensor_a = paddle.ones(self.shape)
tensor_a.stop_gradient = False
tensor_b = tensor_a * 0.5
tensor_b.uniform_(min=-2, max=2)
loss = tensor_b.sum()
loss.backward()
uniform_grad = tensor_b.grad.numpy()
self.assertTrue((uniform_grad == 0).all())
places = ['cpu']
if fluid.core.is_compiled_with_cuda():
places.append('gpu')
for place in places:
paddle.set_device(place)
test_grad()
if __name__ == '__main__':
unittest.main()
@@ -180,6 +180,7 @@ from .random import multinomial # noqa: F401
from .random import standard_normal # noqa: F401
from .random import normal # noqa: F401
from .random import uniform # noqa: F401
from .random import uniform_ # noqa: F401
from .random import randn # noqa: F401
from .random import rand # noqa: F401
from .random import randint # noqa: F401
@@ -371,6 +372,7 @@ tensor_method_func = [ #noqa
'bitwise_xor',
'bitwise_not',
'broadcast_tensors',
'uniform_',
]
#this list used in math_op_patch.py for magic_method bind
......
@@ -15,7 +15,7 @@
# TODO: define random functions
from ..fluid import core
-from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_
+from ..fluid.framework import in_dygraph_mode, Variable, convert_np_dtype_to_dtype_, dygraph_only
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import check_variable_and_dtype, check_type, check_dtype, check_shape
from ..fluid.layers import utils
@@ -444,9 +444,9 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
to generate, ``min`` is included in the range. Default is -1.0.
max(float|int, optional): The upper bound on the range of random values
to generate, ``max`` is excluded in the range. Default is 1.0.
-        seed(int, optional): Random seed used for generating samples. 0 means
-            use a seed generated by the system. Note that if seed is not 0,
-            this operator will always generate the same random numbers every
+        seed(int, optional): Random seed used for generating samples. If seed is 0,
+            it will use the seed of the global default generator (which can be set by paddle.seed).
+            Note that if seed is not 0, this operator will always generate the same random numbers every
time. Default is 0.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
@@ -520,6 +520,45 @@ def uniform(shape, dtype=None, min=-1.0, max=1.0, seed=0, name=None):
return out
@dygraph_only
def uniform_(x, min=-1.0, max=1.0, seed=0, name=None):
"""
This is the in-place version of OP ``uniform``, which returns a Tensor filled
with random values sampled from a uniform distribution. The output Tensor is
the input ``x`` itself, modified in place. Please refer to :ref:`api_tensor_uniform`.
Args:
x(Tensor): The input tensor to be filled with random values.
min(float|int, optional): The lower bound on the range of random values
to generate, ``min`` is included in the range. Default is -1.0.
max(float|int, optional): The upper bound on the range of random values
to generate, ``max`` is excluded in the range. Default is 1.0.
seed(int, optional): Random seed used for generating samples. If seed is 0,
it will use the seed of the global default generator (which can be set by paddle.seed).
Note that if seed is not 0, this operator will always generate the same random numbers every
time. Default is 0.
name(str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: The input tensor x filled with random values sampled from a uniform
distribution in the range [``min``, ``max``).
Examples:
.. code-block:: python
import paddle
# example:
x = paddle.ones(shape=[3, 4])
x.uniform_()
print(x)
# [[ 0.84524226, 0.6921872, 0.56528175, 0.71690357], # random
# [-0.34646994, -0.45116323, -0.09902662, -0.11397249], # random
# [ 0.433519, 0.39483607, -0.8660099, 0.83664286]] # random
"""
return core.ops.uniform_random_inplace_(x, 'min', min, 'max', max, 'seed',
seed)
def randint(low=0, high=None, shape=[1], dtype=None, name=None):
"""
This OP returns a Tensor filled with random integers from a discrete uniform
......