Unverified commit 9297f49e authored by cc, committed by GitHub

[OP] Add randperm op (#23292)

Parent 08e3d9c0
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/randperm_op.h"
#include <string>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class RandpermOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"The output(Out) of randperm op must not be null."));
int n = ctx->Attrs().Get<int>("n");
PADDLE_ENFORCE_GT(
n, 0, platform::errors::InvalidArgument(
"The input(n) of randperm op must be greater than 0."));
ctx->SetOutputDim("Out", framework::make_ddim({n}));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto data_type =
static_cast<framework::proto::VarType::Type>(ctx.Attr<int>("dtype"));
return framework::OpKernelType(data_type, ctx.GetPlace());
}
};
class RandpermOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddOutput("Out", "The output tensor of randperm op.");
    AddAttr<int>(
        "n", "The upper bound (exclusive) of the permutation; it must be greater than 0.");
AddAttr<int>("dtype",
"The data type of output tensor. "
"Default: 3[int64].")
.SetDefault(framework::proto::VarType::INT64);
AddAttr<int>("seed",
"Random seed used for permute samples. "
"0 means use a seed generated by the system."
"Note that if seed is not 0, this operator will always "
"generate the same random permutation every time. "
"Default: 0.")
.SetDefault(0);
AddComment(R"DOC(
This operator returns a random permutation of integers from 0 to n-1.
)DOC");
}
};
class RandpermOpVarTypeInference : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto var_data_type = static_cast<framework::proto::VarType::Type>(
boost::get<int>(ctx->GetAttr("dtype")));
auto out_var_name = ctx->Output("Out").front();
ctx->SetDataType(out_var_name, var_data_type);
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(
randperm, paddle::operators::RandpermOp, paddle::operators::RandpermOpMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
paddle::operators::RandpermOpVarTypeInference);
template <typename T>
using kernel =
paddle::operators::RandpermKernel<paddle::platform::CPUDeviceContext, T>;
REGISTER_OP_CPU_KERNEL(randperm, kernel<int64_t>, kernel<int>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/randperm_op.h"
template <typename T>
using kernel =
paddle::operators::RandpermKernel<paddle::platform::CUDADeviceContext, T>;
REGISTER_OP_CUDA_KERNEL(randperm, kernel<int64_t>, kernel<int>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <cstdlib>
#include <ctime>
#include <random>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace operators {
template <typename T>
static inline void random_permute(T* data_ptr, int num, unsigned int seed) {
  // Fill the buffer with the identity permutation 0, 1, ..., num - 1.
  for (int i = 0; i < num; ++i) {
    data_ptr[i] = static_cast<T>(i);
  }
  // A seed of 0 means "use a seed generated by the system".
  if (seed == 0) {
    seed = std::random_device()();
  }
  std::default_random_engine engine(seed);
  std::shuffle(data_ptr, data_ptr + num, engine);
}
template <typename DeviceContext, typename T>
class RandpermKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
int n = ctx.Attr<int>("n");
unsigned int seed = static_cast<unsigned int>(ctx.Attr<int>("seed"));
framework::Variable* out_var = ctx.OutputVar("Out");
framework::Tensor* out_tensor =
framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
if (platform::is_cpu_place(ctx.GetPlace())) {
T* out_data = out_tensor->mutable_data<T>(platform::CPUPlace());
      random_permute<T>(out_data, n, seed);
} else {
framework::Tensor tmp_tensor;
tmp_tensor.Resize(framework::make_ddim({n}));
T* tmp_data = tmp_tensor.mutable_data<T>(platform::CPUPlace());
      random_permute<T>(tmp_data, n, seed);
      framework::TensorCopy(tmp_tensor, ctx.GetPlace(), out_tensor);
}
}
};
} // namespace operators
} // namespace paddle
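For reference, a minimal NumPy sketch of what the CPU path of the kernel above computes: fill a buffer with the identity permutation 0..n-1, then shuffle it in place, where a seed of 0 is replaced by a system-generated seed. The helper name and the use of NumPy are illustrative only, not part of the operator code.

import numpy as np

def randperm_reference(n, seed=0):
    """Reference sketch of the randperm CPU kernel: a shuffled arange(n)."""
    if n <= 0:
        raise ValueError("n must be greater than 0.")
    # seed == 0 means "use a seed generated by the system", as in the op.
    rng = np.random.RandomState(seed if seed != 0 else None)
    data = np.arange(n, dtype=np.int64)  # identity permutation 0..n-1
    rng.shuffle(data)  # in-place shuffle
    return data

# With a fixed non-zero seed the permutation is the same on every call.
print(randperm_reference(10, seed=1))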
@@ -87,7 +87,7 @@ from .tensor.logic import elementwise_equal #DEFINE_ALIAS
# from .tensor.random import uniform #DEFINE_ALIAS
# from .tensor.random import shuffle #DEFINE_ALIAS
# from .tensor.random import randn #DEFINE_ALIAS
# from .tensor.random import randperm #DEFINE_ALIAS
from .tensor.random import randperm
# from .tensor.random import rand #DEFINE_ALIAS
# from .tensor.random import randint #DEFINE_ALIAS
# from .tensor.math import abs #DEFINE_ALIAS
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.op import Operator
from paddle.fluid import Program, program_guard
def check_randperm_out(n, data_np):
    assert isinstance(data_np, np.ndarray), \
        "The input data_np should be np.ndarray."
    gt_sorted = np.arange(n)
    out_sorted = np.sort(data_np)
    # The output is a valid permutation iff its sorted values equal 0..n-1.
    return (gt_sorted == out_sorted).all()
def error_msg(data_np):
return "The sorted ground truth and sorted out should " + \
"be equal, out = " + str(data_np)
def convert_dtype(dtype_str):
    # Map the dtype string to the framework proto::VarType enum value
    # expected by the "dtype" attribute (INT32 = 2, INT64 = 3).
    dtype_str_list = ["int32", "int64"]
    dtype_num_list = [2, 3]
    assert dtype_str in dtype_str_list, dtype_str + \
        " should be in " + str(dtype_str_list)
    return dtype_num_list[dtype_str_list.index(dtype_str)]
class TestRandpermOp(OpTest):
""" Test randperm op."""
def setUp(self):
self.op_type = "randperm"
self.n = 200
self.dtype = "int64"
self.device = None
self.seed = 0
self.inputs = {}
self.outputs = {"Out": np.zeros((self.n)).astype(self.dtype)}
self.init_attrs()
self.attrs = {
"n": self.n,
"dtype": convert_dtype(self.dtype),
"device": self.device,
"seed": self.seed,
}
def init_attrs(self):
pass
def test_check_output(self):
self.check_output_customized(self.verify_output)
def verify_output(self, outs):
out_np = np.array(outs[0])
self.assertTrue(
check_randperm_out(self.n, out_np), msg=error_msg(out_np))
class TestRandpermOp_attr_n(TestRandpermOp):
""" Test randperm op for attr n. """
def init_attrs(self):
self.n = 10000
class TestRandpermOp_attr_int32(TestRandpermOp):
""" Test randperm op for attr int32 dtype. """
def init_attrs(self):
self.dtype = "int32"
class TestRandpermOp_attr_device_cpu(TestRandpermOp):
""" Test randperm op for cpu device. """
def init_attrs(self):
self.device = "cpu"
class TestRandpermOp_attr_device_gpu(TestRandpermOp):
""" Test randperm op for gpu device. """
def init_attrs(self):
self.device = "gpu"
class TestRandpermOp_attr_seed(TestRandpermOp):
""" Test randperm op for attr seed. """
def init_attrs(self):
self.seed = 10
class TestRandpermOpError(unittest.TestCase):
""" Test randperm op for raise error. """
def test_errors(self):
main_prog = Program()
start_prog = Program()
with program_guard(main_prog, start_prog):
def test_Variable():
out = np.arange(10)
paddle.randperm(n=10, out=out)
self.assertRaises(TypeError, test_Variable)
def test_value():
paddle.randperm(n=-3)
self.assertRaises(ValueError, test_value)
class TestRandpermOp_attr_out(unittest.TestCase):
""" Test randperm op for attr out. """
def test_attr_tensor_API(self):
startup_program = fluid.Program()
train_program = fluid.Program()
with fluid.program_guard(train_program, startup_program):
n = 10
data_1 = fluid.layers.fill_constant([n], "int64", 3)
paddle.randperm(n=n, out=data_1)
data_2 = paddle.randperm(n=n, dtype="int32", device="cpu")
place = fluid.CPUPlace()
if fluid.core.is_compiled_with_cuda():
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
exe.run(startup_program)
outs = exe.run(train_program, fetch_list=[data_1, data_2])
out_np = np.array(outs[0])
self.assertTrue(
check_randperm_out(n, out_np), msg=error_msg(out_np))
class TestRandpermDygraphMode(unittest.TestCase):
def test_check_output(self):
with fluid.dygraph.guard():
n = 10
data_1 = paddle.randperm(n, dtype="int64")
data_1_np = data_1.numpy()
self.assertTrue(
check_randperm_out(n, data_1_np), msg=error_msg(data_1_np))
data_2 = paddle.randperm(n, dtype="int32", device="cpu")
data_2_np = data_2.numpy()
self.assertTrue(
check_randperm_out(n, data_2_np), msg=error_msg(data_2_np))
if __name__ == "__main__":
unittest.main()
@@ -63,9 +63,9 @@ from .logic import elementwise_equal #DEFINE_ALIAS
# from .random import uniform #DEFINE_ALIAS
# from .random import shuffle #DEFINE_ALIAS
# from .random import randn #DEFINE_ALIAS
# from .random import randperm #DEFINE_ALIAS
# from .random import rand #DEFINE_ALIAS
# from .random import randint #DEFINE_ALIAS
from .random import randperm
# from .math import abs #DEFINE_ALIAS
# from .math import acos #DEFINE_ALIAS
# from .math import asin #DEFINE_ALIAS
@@ -17,6 +17,100 @@
# 'uniform',
# 'shuffle',
# 'randn',
# 'randperm',
# 'rand',
# 'randint']
from ..fluid import core
from ..fluid.framework import device_guard, in_dygraph_mode, _varbase_creator
from ..fluid.layers.layer_function_generator import templatedoc
from ..fluid.layer_helper import LayerHelper
from ..fluid.data_feeder import convert_dtype, check_variable_and_dtype, check_type, check_dtype
__all__ = ['randperm']
@templatedoc()
def randperm(n,
out=None,
dtype="int64",
device=None,
stop_gradient=True,
seed=0):
"""
${comment}
Args:
        n (int): The upper bound (exclusive) of the permutation; it must be
            greater than 0.
        out (Variable, optional): Optional output which can be any created
            Variable that meets the requirements to store the result of the
            operation. If out is None, a new Variable will be created to store
            the result. Default: None.
        dtype (np.dtype|core.VarDesc.VarType|str, optional): The data type of
            the output Tensor. Supported data types: int64, int32.
            Default: int64.
        device (str, optional): Specifies whether the output variable is
            stored in CPU or GPU memory. Supported values: None, 'cpu', 'gpu'.
            If it is None, the device is assigned automatically. Default: None.
        stop_gradient (bool, optional): Whether to stop gradient computation
            for the returned tensor. Default: True.
        seed (int, optional): Random seed used to permute the samples. If seed
            is 0, a system-generated seed is used. Note that if seed is not 0,
            this operator always generates the same random permutation.
            Default: 0.
Returns:
${out_comment}.
Return Type:
${out_type}
Examples:
.. code-block:: python
import paddle
import paddle.fluid as fluid
num = 6
is_use_gpu = False
data_1 = paddle.randperm(num)
fluid.layers.Print(data_1)
data_2 = paddle.randperm(num, dtype="int32", seed=1)
fluid.layers.Print(data_2)
data_3 = paddle.randperm(num, stop_gradient=False, device="cpu")
fluid.layers.Print(data_3)
paddle.randperm(num, out=data_3)
fluid.layers.Print(data_3)
place = fluid.CUDAPlace(0) if is_use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
exe.run()
"""
if n < 1:
raise ValueError("The input n should be greater than 0 in randperm op.")
check_dtype(dtype, 'dtype', ['int64', 'int32'], 'randperm')
dtype = convert_dtype(dtype)
    if device not in [None, 'cpu', 'gpu']:
        raise ValueError("The input device should be in [None, 'cpu', 'gpu'].")
check_type(stop_gradient, 'stop_gradient', bool, 'randperm')
helper = LayerHelper("randperm", **locals())
if out is None:
out = helper.create_variable_for_type_inference(dtype=dtype)
else:
check_variable_and_dtype(out, 'out', [dtype], 'randperm')
if stop_gradient:
out.stop_gradient = True
inputs = dict()
outputs = {'Out': [out]}
attrs = {'n': n, 'dtype': out.dtype, 'seed': seed}
with device_guard(device):
helper.append_op(
type='randperm', inputs=inputs, outputs=outputs, attrs=attrs)
return out
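A minimal dygraph usage sketch of this API, mirroring the unit tests above (the value of n and the dtypes are illustrative):

import paddle
import paddle.fluid as fluid

with fluid.dygraph.guard():
    # Random permutation of 0..9 with the default int64 dtype.
    perm = paddle.randperm(10)
    print(perm.numpy())

    # int32 result placed in CPU memory.
    perm_cpu = paddle.randperm(10, dtype="int32", device="cpu")
    print(perm_cpu.numpy())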