未验证 提交 97cebfa4 编写于 作者: Z Zhang Ting 提交者: GitHub

add dtype for unique (#26655)

* update doc, test=document_fix

* add attr(dtype)

* refine code
上级 07e3b9a3
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace operators {
......@@ -149,3 +150,34 @@ REGISTER_OP_CPU_KERNEL(
ops::UniqueKernel<paddle::platform::CPUDeviceContext, double>,
ops::UniqueKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::UniqueKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_VERSION(unique)
.AddCheckpoint(
R"ROC(
Upgrade unique, add 2 outputs [Indices, Counts] and 5 attribute
[return_index, return_inverse, return_counts, axis, is_sorted].
)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewOutput("Indices",
"The indices of the input tensor that result in the "
"unique tensor.")
.NewOutput("Counts", "The counts for each unique element.")
.NewAttr("return_index",
"If True, also return the indices of the input"
" tensor that result in the unique Tensor.",
false)
.NewAttr("return_inverse",
"If True, also return the indices for where elements"
" in the original input ended up in the returned unique "
"tensor.",
false)
.NewAttr("return_counts",
"If True, also return the counts for each unique element.",
false)
.NewAttr("axis",
"The axis to apply unique. If None, the input will be "
"flattened.",
{})
.NewAttr("is_sorted",
"If True, the unique elements of X are in ascending order."
"Otherwise, the unique elements are not sorted.",
false));
......@@ -131,22 +131,22 @@ static bool Equal(const framework::Tensor& a, const framework::Tensor& b) {
return true;
}
template <typename T>
template <typename InT, typename IndexT>
static void UniqueFlattendTensor(const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out, bool return_index,
bool return_inverse, bool return_counts) {
const T* in_data = in.data<T>();
std::set<T> unique(in_data, in_data + in.numel());
const InT* in_data = in.data<InT>();
std::set<InT> unique(in_data, in_data + in.numel());
out->Resize(framework::make_ddim({static_cast<int64_t>(unique.size())}));
auto out_data = out->mutable_data<T>(context.GetPlace());
auto out_data = out->mutable_data<InT>(context.GetPlace());
std::copy(unique.begin(), unique.end(), out_data);
if (return_index) {
auto* indices = context.Output<framework::Tensor>("Indices");
indices->Resize(framework::make_ddim({out->numel()}));
auto indices_data = indices->mutable_data<int64_t>(context.GetPlace());
std::unordered_map<T, int64_t> indices_map;
auto indices_data = indices->mutable_data<IndexT>(context.GetPlace());
std::unordered_map<InT, IndexT> indices_map;
indices_map.reserve(out->numel());
for (int64_t i = 0; i < in.numel(); ++i) {
if (indices_map.find(in_data[i]) != indices_map.end()) continue;
......@@ -160,8 +160,8 @@ static void UniqueFlattendTensor(const framework::ExecutionContext& context,
if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index");
inverse->Resize(framework::make_ddim({in.numel()}));
auto inverse_data = inverse->mutable_data<int64_t>(context.GetPlace());
std::unordered_map<T, int64_t> inverse_map;
auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
std::unordered_map<InT, IndexT> inverse_map;
inverse_map.reserve(out->numel());
for (int64_t i = 0; i < out->numel(); ++i) {
inverse_map[out_data[i]] = i;
......@@ -174,8 +174,8 @@ static void UniqueFlattendTensor(const framework::ExecutionContext& context,
if (return_counts) {
auto* count = context.Output<framework::Tensor>("Counts");
count->Resize(framework::make_ddim({out->numel()}));
auto count_data = count->mutable_data<int64_t>(context.GetPlace());
std::unordered_map<T, int64_t> counts_map;
auto count_data = count->mutable_data<IndexT>(context.GetPlace());
std::unordered_map<InT, IndexT> counts_map;
counts_map.reserve(out->numel());
for (int64_t i = 0; i < out->numel(); ++i) {
counts_map[out_data[i]] = 0;
......@@ -189,13 +189,13 @@ static void UniqueFlattendTensor(const framework::ExecutionContext& context,
}
}
template <class ForwardIt, typename T>
template <class ForwardIt, typename InT, typename IndexT>
static ForwardIt UniqueDimImpl(const framework::ExecutionContext& context,
ForwardIt first, ForwardIt last,
const std::vector<int64_t>& sorted_indices_vec,
std::vector<int64_t>* inverse_vec,
std::vector<int64_t>* counts_vec,
std::vector<int64_t>* indices_vec) {
const std::vector<IndexT>& sorted_indices_vec,
std::vector<IndexT>* inverse_vec,
std::vector<IndexT>* counts_vec,
std::vector<IndexT>* indices_vec) {
if (first == last) {
return last;
}
......@@ -210,7 +210,7 @@ static ForwardIt UniqueDimImpl(const framework::ExecutionContext& context,
while (++first != last) {
int64_t idx_first = std::distance(begin, first);
int64_t idx_result = std::distance(begin, result);
if (!Equal<T>(*result, *first)) {
if (!Equal<InT>(*result, *first)) {
if (++result != first) {
*result = std::move(*first);
}
......@@ -223,7 +223,7 @@ static ForwardIt UniqueDimImpl(const framework::ExecutionContext& context,
return ++result;
}
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename InT, typename IndexT>
static void UniqueDim(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor* out,
bool return_index, bool return_inverse,
......@@ -239,25 +239,25 @@ static void UniqueDim(const framework::ExecutionContext& context,
framework::Tensor in_trans;
framework::DDim in_trans_dims = framework::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims);
in_trans.mutable_data<T>(context.GetPlace());
in_trans.mutable_data<InT>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
TransCompute<DeviceContext, T>(in.dims().size(), dev_ctx, in, &in_trans,
permute);
TransCompute<DeviceContext, InT>(in.dims().size(), dev_ctx, in, &in_trans,
permute);
// reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
framework::DDim in_trans_flat_dims =
framework::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims);
// sort indices
std::vector<int64_t> sorted_indices_vec(in_trans.dims()[0]);
std::vector<IndexT> sorted_indices_vec(in_trans.dims()[0]);
std::iota(sorted_indices_vec.begin(), sorted_indices_vec.end(), 0);
int64_t col = in_trans.dims()[1];
const T* in_trans_data = in_trans.data<T>();
const InT* in_trans_data = in_trans.data<InT>();
std::sort(sorted_indices_vec.begin(), sorted_indices_vec.end(),
[&](int64_t a, int64_t b) -> bool {
for (int64_t i = 0; i < col; ++i) {
T lhs = in_trans_data[i + a * col];
T rhs = in_trans_data[i + b * col];
InT lhs = in_trans_data[i + a * col];
InT rhs = in_trans_data[i + b * col];
if (lhs < rhs) {
return true;
} else if (lhs > rhs) {
......@@ -270,18 +270,19 @@ static void UniqueDim(const framework::ExecutionContext& context,
// sort tensor according to indices
framework::Tensor input_sorted;
input_sorted.Resize(in_trans_dims);
input_sorted.mutable_data<T>(context.GetPlace());
T* input_sorted_data = input_sorted.data<T>();
input_sorted.mutable_data<InT>(context.GetPlace());
InT* input_sorted_data = input_sorted.data<InT>();
for (size_t i = 0; i < sorted_indices_vec.size(); ++i) {
memcpy(input_sorted_data + i * col,
in_trans_data + sorted_indices_vec[i] * col, col * sizeof(T));
in_trans_data + static_cast<int64_t>(sorted_indices_vec[i]) * col,
col * sizeof(InT));
}
std::vector<framework::Tensor> input_unbind = Unbind(input_sorted);
std::vector<int64_t> inverse_vec(sorted_indices_vec.size(), 0);
std::vector<int64_t> counts_vec(sorted_indices_vec.size(), 0);
std::vector<int64_t> indices_vec(sorted_indices_vec.size(), 0);
auto last = UniqueDimImpl<std::vector<framework::Tensor>::iterator, T>(
std::vector<IndexT> inverse_vec(sorted_indices_vec.size(), 0);
std::vector<IndexT> counts_vec(sorted_indices_vec.size(), 0);
std::vector<IndexT> indices_vec(sorted_indices_vec.size(), 0);
auto last = UniqueDimImpl<std::vector<framework::Tensor>::iterator, InT>(
context, input_unbind.begin(), input_unbind.end(), sorted_indices_vec,
&inverse_vec, &counts_vec, &indices_vec);
input_unbind.erase(last, input_unbind.end());
......@@ -289,18 +290,18 @@ static void UniqueDim(const framework::ExecutionContext& context,
indices_vec.erase(indices_vec.begin() + input_unbind.size(),
indices_vec.end());
math::ConcatFunctor<DeviceContext, T> concat_functor;
math::ConcatFunctor<DeviceContext, InT> concat_functor;
framework::Tensor out_trans;
std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
out_trans_dims_vec[0] = input_unbind.size();
out_trans.Resize(framework::make_ddim(out_trans_dims_vec));
out_trans.mutable_data<T>(context.GetPlace());
out_trans.mutable_data<InT>(context.GetPlace());
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(framework::make_ddim(out_trans_dims_vec));
out->mutable_data<T>(context.GetPlace());
out->mutable_data<InT>(context.GetPlace());
concat_functor(dev_ctx, input_unbind, 0, &out_trans);
TransCompute<DeviceContext, T>(out_trans.dims().size(), dev_ctx, out_trans,
out, permute);
TransCompute<DeviceContext, InT>(out_trans.dims().size(), dev_ctx, out_trans,
out, permute);
if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index");
......@@ -318,15 +319,80 @@ static void UniqueDim(const framework::ExecutionContext& context,
}
}
template <typename DeviceContext, typename InT>
struct UniqueFlattendTensorFunctor {
const framework::ExecutionContext& ctx_;
const framework::Tensor& in_;
framework::Tensor* out_;
const bool return_index_;
const bool return_inverse_;
const bool return_counts_;
UniqueFlattendTensorFunctor(const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out, bool return_index,
bool return_inverse, bool return_counts)
: ctx_(context),
in_(in),
out_(out),
return_index_(return_index),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
template <typename IndexT>
void apply() const {
UniqueFlattendTensor<InT, IndexT>(ctx_, in_, out_, return_index_,
return_inverse_, return_counts_);
}
};
template <typename DeviceContext, typename InT>
struct UniqueDimFunctor {
const framework::ExecutionContext& ctx_;
const framework::Tensor& in_;
framework::Tensor* out_;
const int axis_;
const bool return_index_;
const bool return_inverse_;
const bool return_counts_;
UniqueDimFunctor(const framework::ExecutionContext& context,
const framework::Tensor& in, framework::Tensor* out,
const int axis, bool return_index, bool return_inverse,
bool return_counts)
: ctx_(context),
in_(in),
out_(out),
axis_(axis),
return_index_(return_index),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
template <typename IndexT>
void apply() const {
UniqueDim<DeviceContext, InT, IndexT>(
ctx_, in_, out_, return_index_, return_inverse_, return_counts_, axis_);
}
};
template <typename DeviceContext, typename T>
class UniqueKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto data_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
if (data_type == framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x->numel(), INT_MAX,
platform::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x->numel()));
}
if (!context.Attr<bool>("is_sorted")) {
auto data_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
auto* index = context.Output<framework::Tensor>("Index");
framework::VisitDataType(data_type, UniqueOpFunctor<T>(out, index, x));
......@@ -339,12 +405,16 @@ class UniqueKernel : public framework::OpKernel<T> {
bool return_counts = context.Attr<bool>("return_counts");
if (axis_vec.empty()) {
UniqueFlattendTensor<T>(context, *x, out, return_index, return_inverse,
return_counts);
framework::VisitDataTypeSmall(
data_type,
UniqueFlattendTensorFunctor<DeviceContext, T>(
context, *x, out, return_index, return_inverse, return_counts));
} else {
int axis = axis_vec[0];
UniqueDim<DeviceContext, T>(context, *x, out, return_index,
return_inverse, return_counts, axis);
framework::VisitDataTypeSmall(
data_type, UniqueDimFunctor<DeviceContext, T>(
context, *x, out, axis, return_index, return_inverse,
return_counts));
}
}
};
......
......@@ -14098,17 +14098,11 @@ def sign(x):
def unique(x, dtype='int32'):
"""
:alias_main: paddle.unique
:alias: paddle.unique,paddle.tensor.unique,paddle.tensor.manipulation.unique
:old_api: paddle.fluid.layers.unique
**unique**
Return a unique tensor for `x` and an index tensor pointing to this unique tensor.
Args:
x(Variable): A 1-D input tensor.
dtype(np.dtype|core.VarDesc.VarType|str): The type of index tensor: int32, int64.
x(Tensor): A 1-D input tensor, it's data type should be float32, float64, int32, int64.
dtype(np.dtype|str, optional): The type of index tensor: int32, int64. Default: int32.
Returns:
tuple: (out, index). `out` is the unique tensor for `x`, with identical dtype to `x`, and \
......
......@@ -233,6 +233,24 @@ class TestUniqueAPI(unittest.TestCase):
self.assertTrue((counts.numpy() == np_counts).all(), True)
paddle.enable_static()
def test_dygraph_attr_dtype(self):
paddle.disable_static()
x_data = x_data = np.random.randint(0, 10, (120))
x = paddle.to_tensor(x_data)
out, indices, inverse, counts = paddle.unique(
x,
return_index=True,
return_inverse=True,
return_counts=True,
dtype="int32")
expected_out, np_indices, np_inverse, np_counts = np.unique(
x_data, return_index=True, return_inverse=True, return_counts=True)
self.assertTrue((out.numpy() == expected_out).all(), True)
self.assertTrue((indices.numpy() == np_indices).all(), True)
self.assertTrue((inverse.numpy() == np_inverse).all(), True)
self.assertTrue((counts.numpy() == np_counts).all(), True)
paddle.enable_static()
def test_static_graph(self):
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
......@@ -282,6 +300,9 @@ class TestUniqueError(unittest.TestCase):
def test_axis():
result = paddle.unique(x, axis='12')
def test_dtype():
result = paddle.unique(x, dtype='float64')
self.assertRaises(TypeError, test_axis)
......
......@@ -612,6 +612,7 @@ def unique(x,
return_inverse=False,
return_counts=False,
axis=None,
dtype="int64",
name=None):
"""
Returns the unique elements of `x` in ascending order.
......@@ -625,6 +626,8 @@ def unique(x,
return_counts(bool, optional): If True, also return the counts for each unique element.
axis(int, optional): The axis to apply unique. If None, the input will be flattened.
Default: None.
dtype(np.dtype|str, optional): The date type of `indices` or `inverse` tensor: int32 or int64.
Default: int64.
name(str, optional): Name for the operation. For more information, please refer to
:ref:`api_guide_Name`. Default: None.
......@@ -650,6 +653,7 @@ def unique(x,
np_counts = counts.numpy() # [1 1 3 1]
x_data = np.array([[2, 1, 3], [3, 0, 1], [2, 1, 3]])
x = paddle.to_tensor(x_data)
unique = paddle.unique(x)
np_unique = unique.numpy() # [0 1 2 3]
......@@ -662,11 +666,10 @@ def unique(x,
axis = []
else:
axis = [axis]
attr_dtype = convert_np_dtype_to_dtype_(dtype)
if in_dygraph_mode():
out, inverse, indices, counts = core.ops.unique(
x, 'dtype',
convert_np_dtype_to_dtype_('int32'), 'return_index', return_index,
x, 'dtype', attr_dtype, 'return_index', return_index,
'return_inverse', return_inverse, 'return_counts', return_counts,
'axis', axis, "is_sorted", True)
outs = [out]
......@@ -687,12 +690,13 @@ def unique(x,
check_type(return_index, 'return_index', bool, 'unique')
check_type(return_inverse, 'return_inverse', bool, 'unique')
check_type(return_counts, 'return_counts', bool, 'unique')
check_dtype(dtype, 'dtype', ['int32', 'int64'], 'unique')
if len(axis) != 0:
check_type(axis[0], 'axis', int, 'unique')
helper = LayerHelper('unique', **locals())
attrs = {
'dtype': int(core.VarDesc.VarType.INT32),
'dtype': attr_dtype,
"return_index": return_index,
"return_inverse": return_inverse,
"return_counts": return_counts,
......@@ -702,19 +706,19 @@ def unique(x,
out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
inverse = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True)
dtype=attr_dtype, stop_gradient=True)
outputs = {"Out": out, "Index": inverse}
outs = [out]
if return_index:
indices = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True)
dtype=attr_dtype, stop_gradient=True)
outputs["Indices"] = indices
outs.append(indices)
if return_inverse:
outs.append(inverse)
if return_counts:
counts = helper.create_variable_for_type_inference(
dtype=core.VarDesc.VarType.INT64, stop_gradient=True)
dtype=attr_dtype, stop_gradient=True)
outputs["Counts"] = counts
outs.append(counts)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册