未验证 提交 009c049e 编写于 作者: S silingtong123 提交者: GitHub

Add randint op API (#23337)

* add randint op
上级 ea6a251c
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/uniform_random_op.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace operators {
template <typename T>
class CPURandintKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
std::vector<int64_t> new_shape;
auto list_new_shape_tensor =
ctx.MultiInput<framework::Tensor>("ShapeTensorList");
if (list_new_shape_tensor.size() > 0 || ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ShapeTensor")) {
auto* shape_tensor = ctx.Input<framework::Tensor>("ShapeTensor");
new_shape = GetNewDataFromShapeTensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
}
}
auto* out = ctx.Output<framework::LoDTensor>("Out");
if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
T* data = out->mutable_data<T>(ctx.GetPlace());
int64_t size = out->numel();
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> dist(ctx.Attr<int>("low"),
ctx.Attr<int>("high") - 1);
for (int64_t i = 0; i < size; ++i) data[i] = dist(gen);
}
};
class RandintOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument("Output(Out) of RandintOp is null."));
PADDLE_ENFORCE_LT(
ctx->Attrs().Get<int>("low"), ctx->Attrs().Get<int>("high"),
platform::errors::InvalidArgument("randint's low must less then high, "
"but received: low = %d, high = %d.",
ctx->Attrs().Get<int>("low"),
ctx->Attrs().Get<int>("high")));
if (ctx->HasInputs("ShapeTensorList")) {
// top prority shape
auto inputs_name = ctx->Inputs("ShapeTensorList");
PADDLE_ENFORCE_GT(
inputs_name.size(), 0,
platform::errors::InvalidArgument(
"Input(ShapeTensorList)'size of Op(randint) can't be zero."
"Please check the Attr(shape)'s size of"
"Op(fluid.layers.randint).)"));
auto out_dims = std::vector<int>(inputs_name.size(), -1);
ctx->SetOutputDim("Out", framework::make_ddim(out_dims));
return;
}
auto& shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
if (ctx->HasInput("ShapeTensor") && shape.empty()) {
auto shape_dims = ctx->GetInputDim("ShapeTensor");
PADDLE_ENFORCE_EQ(shape_dims.size(), 1,
platform::errors::InvalidArgument(
"ShapeError: Input(ShapeTensor)' dimension size of "
"Op(randint) must be 1."
"But received ShapeTensor's dimensions = %d.",
shape_dims.size()));
int num_ele = 1;
for (int i = 0; i < shape_dims.size(); ++i) {
num_ele *= shape_dims[i];
}
auto vec_dims = std::vector<int64_t>(num_ele, -1);
auto out_dims = framework::make_ddim(vec_dims);
ctx->SetOutputDim("Out", out_dims);
return;
}
PADDLE_ENFORCE_EQ(shape.empty(), false,
platform::errors::InvalidArgument(
"if there is no Input(ShapeTensorList) and no "
"Input(ShapeTensor),the "
"attr(shape) information must "
"be set by Attr(shape)."));
std::vector<int64_t> tensor_shape;
tensor_shape.reserve(shape.size());
for (auto dim : shape) {
tensor_shape.push_back(static_cast<int64_t>(dim));
}
ctx->SetOutputDim("Out", framework::make_ddim(tensor_shape));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype")),
ctx.GetPlace());
}
};
class RandintOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("ShapeTensor",
"(Tensor<int64_t> or Tensor<int32_t>, optional) . If provided, "
"randint"
"according to "
"this given shape. It means that it has a higher priority than "
"Attr(shape) but a lower priority than Input(ShapeTensor).")
.AsDispensable();
AddInput("ShapeTensorList",
"(vector<Tensor<int64_t>> or vector<Tensor<int32_t>>, optional). "
"If provided, randint use this. The shape of the tensor "
"must be [1], it has the highest priority comparing with "
"Input(ShapeTensor) and attr(shape).")
.AsDuplicable()
.AsDispensable();
AddOutput("Out", "The output tensor of randint op");
AddComment(R"DOC(
This operator initializes a tensor with random integers sampled from a
uniform distribution. The random result is in set [low, high).
)DOC");
AddAttr<std::vector<int64_t>>("shape", "The shape of the output tensor.")
.SetDefault({});
AddAttr<int>("low",
"The lower bound on the range of random values to generate.");
AddAttr<int>("high",
"The upper bound on the range of random values to generate.");
AddAttr<int>("dtype", "Output tensor data type. [Default INT64].")
.SetDefault(framework::proto::VarType::INT64);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
randint, ops::RandintOp, ops::RandintOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>)
REGISTER_OP_CPU_KERNEL(randint, ops::CPURandintKernel<int>,
ops::CPURandintKernel<int64_t>)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <thrust/random.h>
#include <thrust/transform.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/uniform_random_op.h"
namespace paddle {
namespace operators {
template <typename T>
struct UniformIntGenerator {
T low_, high_;
__host__ __device__ UniformIntGenerator(T low, T high)
: low_(low), high_(high) {}
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(0);
thrust::uniform_int_distribution<T> dist(low_, high_);
rng.discard(n);
T out = dist(rng);
return out;
}
};
// Use std::uniform_int_distribution and thrust::uniform_int_distribution(thrust
// is a std library in CUDA) to
// implement randint.
template <typename T>
class GPURandintKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
std::vector<int64_t> new_shape;
auto list_new_shape_tensor =
context.MultiInput<framework::Tensor>("ShapeTensorList");
if (list_new_shape_tensor.size() > 0 || context.HasInput("ShapeTensor")) {
if (context.HasInput("ShapeTensor")) {
auto* shape_tensor = context.Input<framework::Tensor>("ShapeTensor");
new_shape = GetNewDataFromShapeTensor(shape_tensor);
} else if (list_new_shape_tensor.size() > 0) {
new_shape = GetNewDataFromShapeTensorList(list_new_shape_tensor);
}
}
auto* out = context.Output<framework::LoDTensor>("Out");
if (!new_shape.empty()) out->Resize(framework::make_ddim(new_shape));
T* data = out->mutable_data<T>(context.GetPlace());
T low = static_cast<T>(context.Attr<int>("low"));
T high = static_cast<T>(context.Attr<int>("high")) - 1;
thrust::counting_iterator<unsigned int> index_sequence_begin(0);
int64_t size = out->numel();
thrust::transform(index_sequence_begin, index_sequence_begin + size,
thrust::device_ptr<T>(data),
UniformIntGenerator<T>(low, high));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(randint, ops::GPURandintKernel<int>,
ops::GPURandintKernel<int64_t>)
......@@ -90,7 +90,7 @@ from .tensor.logic import elementwise_equal #DEFINE_ALIAS
# from .tensor.random import randn #DEFINE_ALIAS
from .tensor.random import randperm
# from .tensor.random import rand #DEFINE_ALIAS
# from .tensor.random import randint #DEFINE_ALIAS
from .tensor.random import randint #DEFINE_ALIAS
# from .tensor.math import abs #DEFINE_ALIAS
# from .tensor.math import acos #DEFINE_ALIAS
# from .tensor.math import asin #DEFINE_ALIAS
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import paddle
def output_hist(out):
hist, _ = np.histogram(out, range=(-5, 10))
hist = hist.astype("float32")
hist /= float(out.size)
prob = 0.1 * np.ones((10))
return hist, prob
class TestRandintOp(OpTest):
def setUp(self):
self.op_type = "randint"
self.inputs = {}
self.init_attrs()
self.outputs = {"Out": np.zeros((10000, 784)).astype("float32")}
def init_attrs(self):
self.attrs = {"shape": [10000, 784], "low": -5, "high": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))
class TestRandintOpError(unittest.TestCase):
def test_errors(self):
main_prog = Program()
start_prog = Program()
with program_guard(main_prog, start_prog):
def test_shape():
shape = np.array([2, 3])
paddle.randint(5, shape=shape, dtype='int32')
self.assertRaises(TypeError, test_shape)
def test_dtype():
paddle.randint(5, shape=[32, 32], dtype='float32')
self.assertRaises(TypeError, test_dtype)
def test_low_high():
paddle.randint(low=5, high=5, shape=[32, 32], dtype='int32')
self.assertRaises(ValueError, test_low_high)
class TestRandintOp_attr_tensorlist(OpTest):
def setUp(self):
self.op_type = "randint"
self.new_shape = (10000, 784)
shape_tensor = []
for index, ele in enumerate(self.new_shape):
shape_tensor.append(("x" + str(index), np.ones(
(1)).astype("int64") * ele))
self.inputs = {'ShapeTensorList': shape_tensor}
self.init_attrs()
self.outputs = {"Out": np.zeros((10000, 784)).astype("int32")}
def init_attrs(self):
self.attrs = {"low": -5, "high": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))
class TestRandint_attr_tensor(OpTest):
def setUp(self):
self.op_type = "randint"
self.inputs = {"ShapeTensor": np.array([10000, 784]).astype("int64")}
self.init_attrs()
self.outputs = {"Out": np.zeros((10000, 784)).astype("int64")}
def init_attrs(self):
self.attrs = {"low": -5, "high": 10}
self.output_hist = output_hist
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
hist, prob = self.output_hist(np.array(outs[0]))
self.assertTrue(
np.allclose(
hist, prob, rtol=0, atol=0.1), "hist: " + str(hist))
# Test python API
class TestRandintAPI(unittest.TestCase):
def test_api(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
# results are from [0, 5).
output1 = paddle.randint(5)
# shape is a list and dtype is 'int32'
output2 = paddle.randint(
low=-100, high=100, shape=[64, 64], dtype='int32')
# shape is a tuple and dtype is 'int64'
output3 = paddle.randint(
low=-100, high=100, shape=(32, 32, 3), dtype='int64')
# shape is a tensorlist and dtype is 'float32'
dim_1 = fluid.layers.fill_constant([1], "int64", 32)
dim_2 = fluid.layers.fill_constant([1], "int32", 50)
output4 = paddle.randint(
low=-100, high=100, shape=[dim_1, 5], dtype='int32')
# shape is a tensor and dtype is 'float64'
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
output5 = paddle.randint(
low=1, high=1000, shape=var_shape, dtype='int64')
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(
train_program,
feed={'var_shape': np.array([100, 100]).astype('int64')},
fetch_list=[output1, output2, output3, output4, output5])
class TestRandintDygraphMode(unittest.TestCase):
def test_check_output(self):
with fluid.dygraph.guard():
x = paddle.randint(10, shape=[10], dtype="int32")
x_np = x.numpy()
for i in range(10):
self.assertTrue((x_np[i] >= 0 and x_np[i] < 10))
if __name__ == "__main__":
unittest.main()
......@@ -64,7 +64,7 @@ from .logic import elementwise_equal #DEFINE_ALIAS
# from .random import shuffle #DEFINE_ALIAS
# from .random import randn #DEFINE_ALIAS
# from .random import rand #DEFINE_ALIAS
# from .random import randint #DEFINE_ALIAS
from .random import randint #DEFINE_ALIAS
from .random import randperm
# from .math import abs #DEFINE_ALIAS
# from .math import acos #DEFINE_ALIAS
......
......@@ -13,6 +13,7 @@
# limitations under the License.
# TODO: define random functions
# __all__ = ['gaussin',
# 'uniform',
# 'shuffle',
......@@ -21,12 +22,173 @@
# 'randint']
from ..fluid import core
from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator
from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator, Variable
from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
from ..fluid.layers import utils
from ..fluid.layers.tensor import fill_constant
__all__ = ['randperm', 'randint']
def randint(low,
high=None,
shape=None,
out=None,
dtype=None,
device=None,
stop_gradient=False,
name=None):
"""
This function returns a Tensor filled with random integers from the "discrete uniform" distribution of the
specified data type in the interval [low, high). If high is None (the default), then results are from [0, low).
Args:
low (int): The lower bound on the range of random values to generate, the low is included in the range.
(unless high=None, in which case this parameter is one above the highest such integer).
high (int, optional): The upper bound on the range of random values to generate, the high is excluded
in the range. Default None(see above for behavior if high=None).
shape (list|tuple|Variable, optional): The shape of the output Tensor, if the shape is a list or tuple,
its elements can be an integer
or a Tensor with the shape [1], and the type of the Tensor must be int32 or int64.
If the shape is a Variable, it is a 1-D Tensor, and the type of the Tensor must be
int32 or int64. Default is None, in which case the shape is [1].
out(Variable, optional): Optional output which can be any created
Variable that meets the requirements to store the result of operation.
if out is None, a new Varibale will be create to store the result.
dtype(np.dtype|core.VarDesc.VarType|str, optional): Data type of the output Tensor
which can be int32, int64, if dytpe is `None`, the data
type of created Tensor is `int64`
device(str, optional): This parameter specifies that the Tensor is created
on the GPU or CPU.
stop_gradient(bool, optional): Indicating if we stop gradient from current(out) Variable,
default value is False.
name(str, optional): The default value is None. Normally there is no need for user to set this
property. For more information, please refer to :ref:`api_guide_Name`.
Returns:
Variable: A Tensor of the specified shape filled with random integers.
Raises:
TypeError: Randint's low must less then high.
Examples:
.. code-block:: python
import paddle
import paddle.tensor as tensor
# example 1:
# attr shape is a list which doesn't contain tensor Variable.
result_1 = paddle.randint(low=-5, high=5, shape=[3, 4], dtype="int64")
# example 2:
# attr shape is a list which contains tensor Variable.
dim_1 = fluid.layers.fill_constant([1],"int64",3)
dim_2 = fluid.layers.fill_constant([1],"int32",5)
result_2 = paddle.randint(low=-5, high=5, shape=[dim_1, dim_2], dtype="int32")
# example 3:
# attr shape is a Variable, the data type must be int64 or int32.
var_shape = fluid.data(name='var_shape', shape=[2], dtype="int64")
result_3 = padddle.randint(low=-5, high=5, shape=var_shape, dtype="int32")
var_shape_int32 = fluid.data(name='var_shape_int32', shape=[2], dtype="int32")
result_4 = paddle.randint(low=-5, high=5, shape=var_shape_int32, dtype="int64")
# example 4:
# Input only one parameter
# low=0, high=10, shape=[1], dtype='int64'
result_4 = paddle.randint(10)
"""
__all__ = ['randperm']
def get_new_shape_tensor(list_shape):
new_shape_tensor = []
for dim in list_shape:
if isinstance(dim, Variable):
dim.stop_gradient = True
new_shape_tensor.append(dim)
else:
assert isinstance(dim, int) or isinstance(dim, long)
temp_out = helper.create_variable_for_type_inference('int64')
fill_constant([1], 'int64', dim, force_cpu=True, out=temp_out)
new_shape_tensor.append(temp_out)
return new_shape_tensor
def get_attr_shape(list_shape):
unk_dim_idx = -1
attrs_shape = []
for dim_idx, dim_size in enumerate(list_shape):
if isinstance(dim_size, Variable):
attrs_shape.append(-1)
else:
attrs_shape.append(dim_size)
assert dim_size > 0, (
"Each dimension size given in shape must not be negative "
"except one unknown dimension.")
return attrs_shape
if dtype is None:
dtype = 'int64'
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'randint')
inputs = dict()
attrs = dict()
if shape is None:
shape = [1]
assert len(shape) > 0, ("The size of argument(shape) can't be zero.")
helper = LayerHelper("randint", **locals())
if in_dygraph_mode():
attrs['shape'] = shape
else:
if isinstance(shape, Variable):
shape.stop_gradient = True
inputs["ShapeTensor"] = shape
elif isinstance(shape, (list, tuple)):
assert len(shape) > 0, (
"The size of argument(shape) can't be zero.")
if utils._contain_var(shape):
inputs['ShapeTensorList'] = get_new_shape_tensor(shape)
else:
attrs["shape"] = get_attr_shape(shape)
check_type(shape, 'shape', (list, tuple, Variable), 'randint')
if high is None:
high = low
low = 0
attrs['low'] = low
attrs['high'] = high
if (low >= high):
raise ValueError(
"randint's low must less then high, but received low = {0}, "
"high = {1}".format(low, high))
if out is None:
if name is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
out = helper.create_variable(
name=name, dtype=dtype, persistable=False)
else:
check_dtype(dtype, 'dtype',
convert_dtype(out.dtype), 'randint',
"(The dtype in randint must be the same with out's dtype.)")
attrs['dtype'] = out.dtype
out.stop_gradient = stop_gradient
if device is None:
helper.append_op(
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
else:
with device_guard(device):
helper.append_op(
type='randint',
inputs=inputs,
outputs={'Out': out},
attrs=attrs)
return out
@templatedoc()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册