Unverified · Commit 41a09113 authored by Y yujun, committed by GitHub

[PaddlePaddle hackathon] Add randint_like (#36169)

* add randint like

* rm .cc .cu

* Update unity_build_rule.cmake

* try to make test pass

* use python

* update

* update randint_like

* rename test_randint_like_op -> test_randint_like

* update

* update randint like docs

* update randint like

* update

* update

* add bool

* update randint like test

* update

* update
Parent e4a134ac
......@@ -230,6 +230,7 @@ from .tensor.random import uniform # noqa: F401
from .tensor.random import randn # noqa: F401
from .tensor.random import rand # noqa: F401
from .tensor.random import randint # noqa: F401
from .tensor.random import randint_like # noqa: F401
from .tensor.random import randperm # noqa: F401
from .tensor.search import argmax # noqa: F401
from .tensor.search import argmin # noqa: F401
......@@ -376,6 +377,7 @@ __all__ = [ # noqa
'ParamAttr',
'stanh',
'randint',
'randint_like',
'assign',
'gather',
'scale',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
from paddle.static import program_guard, Program
# Test Python API
class TestRandintLikeAPI(unittest.TestCase):
def setUp(self):
self.x_bool = np.zeros((10, 12)).astype("bool")
self.x_int32 = np.zeros((10, 12)).astype("int32")
self.x_int64 = np.zeros((10, 12)).astype("int64")
self.x_float16 = np.zeros((10, 12)).astype("float16")
self.x_float32 = np.zeros((10, 12)).astype("float32")
self.x_float64 = np.zeros((10, 12)).astype("float64")
self.dtype = ["bool", "int32", "int64", "float16", "float32", "float64"]
self.place=paddle.CUDAPlace(0) if paddle.is_compiled_with_cuda() \
else paddle.CPUPlace()
def test_static_api(self):
paddle.enable_static()
with program_guard(Program(), Program()):
# results are sampled from the [low, high) range set in each case below
x_bool = paddle.fluid.data(
name="x_bool", shape=[10, 12], dtype="bool")
x_int32 = paddle.fluid.data(
name="x_int32", shape=[10, 12], dtype="int32")
x_int64 = paddle.fluid.data(
name="x_int64", shape=[10, 12], dtype="int64")
x_float16 = paddle.fluid.data(
name="x_float16", shape=[10, 12], dtype="float16")
x_float32 = paddle.fluid.data(
name="x_float32", shape=[10, 12], dtype="float32")
x_float64 = paddle.fluid.data(
name="x_float64", shape=[10, 12], dtype="float64")
exe = paddle.static.Executor(self.place)
# x dtype is bool; output dtype in ["bool", "int32", "int64", "float16", "float32", "float64"]
outlist1 = [
paddle.randint_like(
x_bool, low=-10, high=10, dtype=dtype)
for dtype in self.dtype
]
outs1 = exe.run(feed={'x_bool': self.x_bool}, fetch_list=outlist1)
for out, dtype in zip(outs1, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -10) & (out <= 10)).all(), True)
# x dtype is int32; output dtype in ["bool", "int32", "int64", "float16", "float32", "float64"]
outlist2 = [
paddle.randint_like(
x_int32, low=-5, high=10, dtype=dtype)
for dtype in self.dtype
]
outs2 = exe.run(feed={'x_int32': self.x_int32}, fetch_list=outlist2)
for out, dtype in zip(outs2, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -5) & (out <= 10)).all(), True)
# x dtype is int64; output dtype in ["bool", "int32", "int64", "float16", "float32", "float64"]
outlist3 = [
paddle.randint_like(
x_int64, low=-100, high=100, dtype=dtype)
for dtype in self.dtype
]
outs3 = exe.run(feed={'x_int64': self.x_int64}, fetch_list=outlist3)
for out, dtype in zip(outs3, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -100) & (out <= 100)).all(), True)
# x dtype is float16; output dtype in ["bool", "int32", "int64", "float16", "float32", "float64"]
outlist4 = [
paddle.randint_like(
x_float16, low=-3, high=25, dtype=dtype)
for dtype in self.dtype
]
outs4 = exe.run(feed={'x_float16': self.x_float16},
fetch_list=outlist4)
for out, dtype in zip(outs4, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -3) & (out <= 25)).all(), True)
# x dtype is float32; output dtype in ["bool", "int32", "int64", "float16", "float32", "float64"]
outlist5 = [
paddle.randint_like(
x_float32, low=-25, high=25, dtype=dtype)
for dtype in self.dtype
]
outs5 = exe.run(feed={'x_float32': self.x_float32},
fetch_list=outlist5)
for out, dtype in zip(outs5, self.dtype):
self.assertTrue(out.dtype, np.dtype(dtype))
self.assertTrue(((out >= -25) & (out <= 25)).all(), True)
# x dtype is float64; output dtype in ["bool", "int32", "int64", "float16", "float32", "float64"]
outlist6 = [
paddle.randint_like(
x_float64, low=-16, high=16, dtype=dtype)
for dtype in self.dtype
]
outs6 = exe.run(feed={'x_float64': self.x_float64},
fetch_list=outlist6)
for out, dtype in zip(outs6, self.dtype):
self.assertTrue(out.dtype, dtype)
self.assertTrue(((out >= -16) & (out <= 16)).all(), True)
def test_dygraph_api(self):
paddle.disable_static(self.place)
# x dtype ["bool", "int32", "int64", "float16", "float32", "float64"]
for x in [
self.x_bool, self.x_int32, self.x_int64, self.x_float16,
self.x_float32, self.x_float64
]:
x_inputs = paddle.to_tensor(x)
# self.dtype ["bool", "int32", "int64", "float16", "float32", "float64"]
for dtype in self.dtype:
out = paddle.randint_like(
x_inputs, low=-100, high=100, dtype=dtype)
self.assertTrue(out.numpy().dtype, np.dtype(dtype))
self.assertTrue(((out.numpy() >= -100) &
(out.numpy() <= 100)).all(), True)
paddle.enable_static()
def test_errors(self):
paddle.enable_static()
with program_guard(Program(), Program()):
x_bool = paddle.fluid.data(
name="x_bool", shape=[10, 12], dtype="bool")
x_int32 = paddle.fluid.data(
name="x_int32", shape=[10, 12], dtype="int32")
x_int64 = paddle.fluid.data(
name="x_int64", shape=[10, 12], dtype="int64")
x_float16 = paddle.fluid.data(
name="x_float16", shape=[10, 12], dtype="float16")
x_float32 = paddle.fluid.data(
name="x_float32", shape=[10, 12], dtype="float32")
x_float64 = paddle.fluid.data(
name="x_float64", shape=[10, 12], dtype="float64")
# x dtype is bool
# low is 5 and high is 5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_bool, low=5, high=5)
# low (default 0) and high is -5; low must be less than high
self.assertRaises(ValueError, paddle.randint_like, x_bool, high=-5)
# if high is None, low must be greater than 0
self.assertRaises(ValueError, paddle.randint_like, x_bool, low=-5)
# x dtype is int32
# low is 5 and high is 5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_int32, low=5, high=5)
# low (default 0) and high is -5; low must be less than high
self.assertRaises(ValueError, paddle.randint_like, x_int32, high=-5)
# if high is None, low must be greater than 0
self.assertRaises(ValueError, paddle.randint_like, x_int32, low=-5)
# x dtype is int64
# low is 5 and high is 5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_int64, low=5, high=5)
# low (default 0) and high is -5; low must be less than high
self.assertRaises(ValueError, paddle.randint_like, x_int64, high=-5)
# if high is None, low must be greater than 0
self.assertRaises(ValueError, paddle.randint_like, x_int64, low=-5)
# x dtype is float16
# low is 5 and high is 5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_float16, low=5, high=5)
# low (default 0) and high is -5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_float16, high=-5)
# if high is None, low must be greater than 0
self.assertRaises(
ValueError, paddle.randint_like, x_float16, low=-5)
# x dtype is float32
# low is 5 and high is 5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_float32, low=5, high=5)
# low (default 0) and high is -5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_float32, high=-5)
# if high is None, low must be greater than 0
self.assertRaises(
ValueError, paddle.randint_like, x_float32, low=-5)
# x dtype is float64
# low is 5 and high is 5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_float64, low=5, high=5)
# low (default 0) and high is -5; low must be less than high
self.assertRaises(
ValueError, paddle.randint_like, x_float64, high=-5)
# if high is None, low must be greater than 0
self.assertRaises(
ValueError, paddle.randint_like, x_float64, low=-5)
if __name__ == "__main__":
unittest.main()
......@@ -197,6 +197,7 @@ from .random import uniform_ # noqa: F401
from .random import randn # noqa: F401
from .random import rand # noqa: F401
from .random import randint # noqa: F401
from .random import randint_like # noqa: F401
from .random import randperm # noqa: F401
from .search import argmax # noqa: F401
from .search import argmin # noqa: F401
......
......@@ -661,6 +661,180 @@ def randint(low=0, high=None, shape=[1], dtype=None, name=None):
return out
def randint_like(x, low=0, high=None, dtype=None, name=None):
"""
This OP returns a Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with the same shape as ``x``.
The output uses the data type of ``x`` unless ``dtype`` is specified.
If ``high`` is None (the default), the range is [0, ``low``).
Args:
x (Tensor): The input tensor which specifies shape. The dtype of ``x``
can be bool, int32, int64, float16, float32, float64.
low (int): The lower bound on the range of random values to generate.
The ``low`` is included in the range. If ``high`` is None, the
range is [0, ``low``). Default is 0.
high (int, optional): The upper bound on the range of random values to
generate; ``high`` is excluded from the range. If ``high`` is None, the
range is [0, ``low``). Default is None.
dtype (str|np.dtype, optional): The data type of the
output tensor. Supported data types: bool, int32, int64, float16,
float32, float64. If ``dtype`` is None, the data type is the
same as x's data type. Default is None.
name (str, optional): The default value is None. Normally there is no
need for user to set this property. For more information, please
refer to :ref:`api_guide_Name`.
Returns:
Tensor: A Tensor filled with random integers from a discrete uniform
distribution in the range [``low``, ``high``), with the same shape as ``x`` and the specified ``dtype``.
Examples:
.. code-block:: python
import paddle
# example 1:
# dtype is None and the dtype of x is float16
x = paddle.zeros((1,2)).astype("float16")
out1 = paddle.randint_like(x, low=-5, high=5)
print(out1)
print(out1.dtype)
# [[0, -3]] # random
# paddle.float16
# example 2:
# dtype is None and the dtype of x is float32
x = paddle.zeros((1,2)).astype("float32")
out2 = paddle.randint_like(x, low=-5, high=5)
print(out2)
print(out2.dtype)
# [[0, -3]] # random
# paddle.float32
# example 3:
# dtype is None and the dtype of x is float64
x = paddle.zeros((1,2)).astype("float64")
out3 = paddle.randint_like(x, low=-5, high=5)
print(out3)
print(out3.dtype)
# [[0, -3]] # random
# paddle.float64
# example 4:
# dtype is None and the dtype of x is int32
x = paddle.zeros((1,2)).astype("int32")
out4 = paddle.randint_like(x, low=-5, high=5)
print(out4)
print(out4.dtype)
# [[0, -3]] # random
# paddle.int32
# example 5:
# dtype is None and the dtype of x is int64
x = paddle.zeros((1,2)).astype("int64")
out5 = paddle.randint_like(x, low=-5, high=5)
print(out5)
print(out5.dtype)
# [[0, -3]] # random
# paddle.int64
# example 6:
# dtype is float64 and the dtype of x is float32
x = paddle.zeros((1,2)).astype("float32")
out6 = paddle.randint_like(x, low=-5, high=5, dtype="float64")
print(out6)
print(out6.dtype)
# [[0, -1]] # random
# paddle.float64
# example 7:
# dtype is bool and the dtype of x is float32
x = paddle.zeros((1,2)).astype("float32")
out7 = paddle.randint_like(x, low=-5, high=5, dtype="bool")
print(out7)
print(out7.dtype)
# [[0, -1]] # random
# paddle.bool
# example 8:
# dtype is int32 and the dtype of x is float32
x = paddle.zeros((1,2)).astype("float32")
out8 = paddle.randint_like(x, low=-5, high=5, dtype="int32")
print(out8)
print(out8.dtype)
# [[0, -1]] # random
# paddle.int32
# example 9:
# dtype is int64 and the dtype of x is float32
x = paddle.zeros((1,2)).astype("float32")
out9 = paddle.randint_like(x, low=-5, high=5, dtype="int64")
print(out9)
print(out9.dtype)
# [[0, -1]] # random
# paddle.int64
# example 10:
# dtype is int64 and the dtype of x is bool
x = paddle.zeros((1,2)).astype("bool")
out10 = paddle.randint_like(x, low=-5, high=5, dtype="int64")
print(out10)
print(out10.dtype)
# [[0, -1]] # random
# paddle.int64
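# example 11:
# high is None (an illustrative sketch: with high omitted, values are
# drawn from [0, low) and the output keeps the dtype of x)
x = paddle.zeros((1,2)).astype("int32")
out11 = paddle.randint_like(x, low=5)
print(out11)
print(out11.dtype)
# [[3, 1]] # random
# paddle.int32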
"""
if high is None:
if low <= 0:
raise ValueError(
"If high is None, low must be greater than 0, but received low = {0}.".
format(low))
high = low
low = 0
if dtype is None:
dtype = x.dtype
if not isinstance(dtype, core.VarDesc.VarType):
dtype = convert_np_dtype_to_dtype_(dtype)
shape = x.shape
if low >= high:
raise ValueError(
"randint_like's low must less then high, but received low = {0}, "
"high = {1}".format(low, high))
if in_dygraph_mode():
shape = utils.convert_shape_to_list(shape)
out = _C_ops.randint('shape', shape, 'low', low, 'high', high, 'seed',
0, 'dtype', core.VarDesc.VarType.INT64)
out = paddle.cast(out, dtype)
return out
check_shape(shape, 'randint_like')
check_dtype(dtype, 'dtype',
['bool', 'float16', 'float32', 'float64', 'int32',
'int64'], 'randint_like')
inputs = dict()
attrs = {
'low': low,
'high': high,
'seed': 0,
'dtype': core.VarDesc.VarType.INT64
}
utils.get_shape_tensor_inputs(
inputs=inputs, attrs=attrs, shape=shape, op_type='randint_like')
helper = LayerHelper("randint", **locals())
out = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64)
helper.append_op(
type='randint', inputs=inputs, outputs={'Out': out}, attrs=attrs)
out.stop_gradient = True
out = paddle.cast(out, dtype)
return out
def randperm(n, dtype="int64", name=None):
"""
This OP returns a 1-D Tensor filled with random permutation values from 0
......
......@@ -417,11 +417,12 @@ TETRAD_PARALLEL_JOB_NEW = [
'test_fusion_group_op', 'test_imperative_layer_apply',
'test_executor_return_tensor_not_overwriting',
'test_optimizer_in_control_flow', 'test_lookup_table_op', 'test_randint_op',
'test_convert_call', 'test_sigmoid_cross_entropy_with_logits_op',
'copy_cross_scope_test', 'test_normalization_wrapper',
'test_pretrained_model', 'test_flip', 'test_cosine_similarity_api',
'test_cumsum_op', 'test_range', 'test_log_loss_op', 'test_where_index',
'test_tril_triu_op', 'test_lod_reset_op', 'test_lod_tensor', 'test_addmm_op',
'test_randint_like', 'test_convert_call',
'test_sigmoid_cross_entropy_with_logits_op', 'copy_cross_scope_test',
'test_normalization_wrapper', 'test_pretrained_model', 'test_flip',
'test_cosine_similarity_api', 'test_cumsum_op', 'test_range',
'test_log_loss_op', 'test_where_index', 'test_tril_triu_op',
'test_lod_reset_op', 'test_lod_tensor', 'test_addmm_op',
'test_index_select_op', 'test_nvprof', 'test_index_sample_op',
'test_unstack_op', 'test_increment', 'strided_memcpy_test',
'test_target_assign_op', 'test_trt_dynamic_shape_transformer_prune',
......