未验证 提交 3ab1866c 编写于 作者: W wawltor 提交者: GitHub

Add the op of unique_with_counts, expand count function of the op unique (#18720)

* test=develop
Add the op of unique_with_counts, the op is calc the unqiue input of data, and output the corresponding indices and count of data.

* test=develop
Check the input and dtype in the op of unique_with_counts

* test=develop
test=document_preview
update the API.spec for `unique_with_counts`, at the same time, optimize the python api in the op of `unique_with_count`

* test=develop
test=document_preview
Fix some python api problem in the op of `unique_with_counts`, and change the error messsage in this op.

* Fix some API problem in the op of `unique_with_counts`
test=develop
test=document_preview

* test=develop
test=document_preview
Fix the api sample of op `unique_with_counts`, and update api.spec
上级 5cf2d385
......@@ -205,6 +205,7 @@ paddle.fluid.layers.pad2d (ArgSpec(args=['input', 'paddings', 'mode', 'pad_value
paddle.fluid.layers.unstack (ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b0c4ca08d4eb295189e1b107c920d093'))
paddle.fluid.layers.sequence_enumerate (ArgSpec(args=['input', 'win_size', 'pad_value', 'name'], varargs=None, keywords=None, defaults=(0, None)), ('document', 'b870fed41abd2aecf929ece65f555fa1'))
paddle.fluid.layers.unique (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', 'cab0b06e5683875f12f0efc62fa230a9'))
paddle.fluid.layers.unique_with_counts (ArgSpec(args=['x', 'dtype'], varargs=None, keywords=None, defaults=('int32',)), ('document', '1cb59c65b41766116944b8ed1e6ad345'))
paddle.fluid.layers.expand (ArgSpec(args=['x', 'expand_times', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '33bc4f6010282ffe044d77be7ba7c275'))
paddle.fluid.layers.sequence_concat (ArgSpec(args=['input', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', 'b992616c1afbd6b0c2a897ac23036381'))
paddle.fluid.layers.scale (ArgSpec(args=['x', 'scale', 'bias', 'bias_after_scale', 'act', 'name'], varargs=None, keywords=None, defaults=(1.0, 0.0, True, None, None)), ('document', '463e4713806e5adaa4d20a41e2218453'))
......
......@@ -28,10 +28,12 @@ struct UniqueOpFunctor {
framework::Tensor* out_;
framework::Tensor* index_;
const framework::Tensor* in_;
framework::Tensor* count_;
UniqueOpFunctor(framework::Tensor* out, framework::Tensor* index,
const framework::Tensor* in)
: out_(out), index_(index), in_(in) {}
const framework::Tensor* in,
framework::Tensor* count = nullptr)
: out_(out), index_(index), in_(in), count_(count) {}
template <typename IndexT>
void apply() const {
......@@ -50,8 +52,8 @@ struct UniqueOpFunctor {
for (auto i = 0; i < in_->numel(); i++) {
auto it = dict.find(in_data[i]);
if (it == dict.end()) {
dict.insert(std::make_pair(in_data[i], j));
uniq.push_back(in_data[i]);
dict.emplace(std::make_pair(in_data[i], j));
uniq.emplace_back(in_data[i]);
index_data[i] = static_cast<IndexT>(j);
j++;
} else {
......@@ -59,6 +61,37 @@ struct UniqueOpFunctor {
}
}
if (count_ != nullptr) {
// Resize the count tensor dims to allocate the memory
count_->Resize(framework::make_ddim({static_cast<int64_t>(uniq.size())}));
IndexT* count_data = count_->mutable_data<IndexT>(platform::CPUPlace());
// init count_data to 0
memset(count_data, 0, uniq.size() * sizeof(IndexT));
const auto& index_type = index_->type();
bool index_type_match = index_type == framework::proto::VarType::INT32 ||
index_type == framework::proto::VarType::INT64;
PADDLE_ENFORCE(
index_type_match,
"Index holds the wrong type, it holds %s, but desires to be %s or %s",
paddle::framework::DataTypeToString(index_type),
paddle::framework::DataTypeToString(framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
framework::proto::VarType::INT64));
if (index_type == framework::proto::VarType::INT32) {
for (auto i = 0; i < in_->numel(); ++i) {
const IndexT& index = index_data[i];
count_data[static_cast<int32_t>(index)] += static_cast<IndexT>(1);
}
} else {
for (auto i = 0; i < in_->numel(); ++i) {
const IndexT& index = index_data[i];
count_data[static_cast<int64_t>(index)] += static_cast<IndexT>(1);
}
}
}
out_->Resize(framework::make_ddim({static_cast<int64_t>(uniq.size())}));
auto out_data = out_->mutable_data<InT>(platform::CPUPlace());
std::memcpy(out_data, uniq.data(), uniq.size() * sizeof(InT));
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unique_with_counts_op.h"
namespace paddle {
namespace operators {
class UniqueWithCountsOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of UniqueWithCountsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniqueWithCountsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Index"),
"Output(Index) of UniqueWithCountsOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Count"),
"Output(Count) of UniqueWithCountsOp should not be null.");
auto in_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE(in_dims.size() == 1,
"The op of fluid.layers.unique_with_counts, Input(X) should "
"be a vector.");
ctx->SetOutputDim("Out", {-1});
ctx->SetOutputDim("Index", in_dims);
ctx->SetOutputDim("Count", {-1});
}
};
class UniqueWithCountsOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "Input tensor. It should be a 1-D tensor.");
AddAttr<int>("dtype", "data type for output index");
AddOutput("Out", "A unique subsequence for input tensor.");
AddOutput("Index",
"An index tensor pointing to unique subsequence, which has "
"identical shape with input tensor and the data type is set by "
"the attr `dtype`");
AddOutput("Count", "A subsequence for the count of unique index");
AddComment(R"DOC(
Return a unique subsequence for 1-D input tensor, index tensor pointing to this unique subsequence,
and the subsequence for the count of unique index.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(unique_with_counts, ops::UniqueWithCountsOp,
ops::UniqueWithCountsOpMaker);
REGISTER_OP_CPU_KERNEL(unique_with_counts, ops::UniqueWithCountsKernel<float>,
ops::UniqueWithCountsKernel<double>,
ops::UniqueWithCountsKernel<int32_t>,
ops::UniqueWithCountsKernel<int64_t>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <cmath>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/unique_op.h"
namespace paddle {
namespace operators {
template <typename T>
class UniqueWithCountsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto data_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto* index = context.Output<framework::Tensor>("Index");
auto* count = context.Output<framework::Tensor>("Count");
framework::VisitDataType(data_type,
UniqueOpFunctor<T>(out, index, x, count));
}
};
} // namespace operators
} // namespace paddle
......@@ -148,6 +148,7 @@ __all__ = [
'unstack',
'sequence_enumerate',
'unique',
'unique_with_counts',
'expand',
'sequence_concat',
'scale',
......@@ -12277,6 +12278,58 @@ def unique(x, dtype='int32'):
return out, index
def unique_with_counts(x, dtype='int32'):
"""
**unique**
Return a unique tensor for `x` and an index tensor pointing to this unique tensor.
Args:
x(Variable): A 1-D input tensor.
dtype(np.dtype|core.VarDesc.VarType|str): The type of index tensor: int32, int64.
Returns:
tuple: (out, index, count). `out` is the unique tensor for `x`, with identical dtype to `x`, and \
`index` is an index tensor pointing to `out`, by which user can recover the original `x` tensor, \
`count` is count of unqiue element in the `x`.
Examples:
.. code-block:: python
import numpy as np
import paddle.fluid as fluid
x = fluid.layers.assign(np.array([2, 3, 3, 1, 5, 3], dtype='int32'))
out, index, count = fluid.layers.unique_with_counts(x) # out is [2, 3, 1, 5]; index is [0, 1, 1, 2, 3, 1]
# count is [1, 3, 1, 1]
"""
if not (dtype == 'int32' or dtype == 'int64'):
raise TypeError(
"Op unique_with_counts, index dtype must be int32 or int64")
if x is None or len(x.shape) != 1:
raise ValueError(
"Op unique_with_counts, x must not be null and size of dim must be 1"
)
helper = LayerHelper("unique_with_counts", **locals())
out = helper.create_variable_for_type_inference(dtype=x.dtype)
index = helper.create_variable_for_type_inference(dtype)
count = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type='unique_with_counts',
inputs={'X': x},
attrs={'dtype': convert_np_dtype_to_dtype_(dtype)},
outputs={'Out': [out],
'Index': [index],
'Count': [count]})
return out, index, count
def deformable_conv(input,
offset,
mask,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle.fluid.core as core
from paddle.fluid.op import Operator
class TestUniqueWithCountsOp(OpTest):
def setUp(self):
self.op_type = "unique_with_counts"
self.init_config()
def test_check_output(self):
self.check_output()
def init_config(self):
self.inputs = {'X': np.array([2, 3, 3, 1, 5, 3], dtype='int64'), }
self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
self.outputs = {
'Out': np.array(
[2, 3, 1, 5], dtype='int64'),
'Index': np.array(
[0, 1, 1, 2, 3, 1], dtype='int32'),
'Count': np.array(
[1, 3, 1, 1], dtype='int32')
}
class TestOne(TestUniqueWithCountsOp):
def init_config(self):
self.inputs = {'X': np.array([2], dtype='int64'), }
self.attrs = {'dtype': int(core.VarDesc.VarType.INT32)}
self.outputs = {
'Out': np.array(
[2], dtype='int64'),
'Index': np.array(
[0], dtype='int32'),
'Count': np.array(
[1], dtype='int32')
}
class TestRandom(TestUniqueWithCountsOp):
def init_config(self):
input_data = np.random.randint(0, 100, (2000, ), dtype='int64')
self.inputs = {'X': input_data}
self.attrs = {'dtype': int(core.VarDesc.VarType.INT64)}
np_unique, np_index, reverse_index = np.unique(self.inputs['X'], True,
True)
np_tuple = [(np_unique[i], np_index[i]) for i in range(len(np_unique))]
np_tuple.sort(key=lambda x: x[1])
target_out = np.array([i[0] for i in np_tuple], dtype='int64')
target_index = np.array(
[list(target_out).index(i) for i in self.inputs['X']],
dtype='int64')
count = [0 for i in range(len(np_unique))]
for i in range(target_index.shape[0]):
count[target_index[i]] += 1
target_count = np.array(count, dtype='int64')
self.outputs = {
'Out': target_out,
'Index': target_index,
'Count': target_count
}
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册