未验证 提交 be746adf 编写于 作者: Y Yuang Liu 提交者: GitHub

[operator migration] Migrate kernel of unique consecutive op. (#44228)

上级 f1111f3c
......@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/unique_consecutive_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
......@@ -118,11 +117,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive,
ops::UniqueConsecutiveOp,
ops::UniqueConsecutiveOpMaker);
REGISTER_OP_CPU_KERNEL(unique_consecutive,
ops::UniqueConsecutiveKernel<phi::CPUContext, float>,
ops::UniqueConsecutiveKernel<phi::CPUContext, double>,
ops::UniqueConsecutiveKernel<phi::CPUContext, int32_t>,
ops::UniqueConsecutiveKernel<phi::CPUContext, int64_t>);
REGISTER_OP_VERSION(unique_consecutive)
.AddCheckpoint(
R"ROC(
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/unique_op.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unique_functor.h"
namespace phi {
namespace paddle {
namespace operators {
template <typename InT, typename IndexT>
static void UniqueConsecutiveFlattendTensor(
const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out,
bool return_inverse,
bool return_counts) {
template <typename InT, typename IndexT, typename Context>
static void UniqueConsecutiveFlattenedTensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
DenseTensor* inverse,
DenseTensor* count) {
const InT* in_data = in.data<InT>();
std::vector<InT> out_vec(in.numel());
std::vector<IndexT> inverse_vec(in.numel());
......@@ -65,27 +60,57 @@ static void UniqueConsecutiveFlattendTensor(
out_vec.resize(output_size);
out->Resize(phi::make_ddim({output_size}));
auto* out_data = out->mutable_data<InT>(context.GetPlace());
auto* out_data = context.template Alloc<InT>(out);
std::copy(out_vec.begin(), out_vec.end(), out_data);
if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index");
inverse->Resize(phi::make_ddim({in.numel()}));
auto* inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
auto* inverse_data = context.template Alloc<IndexT>(inverse);
std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data);
}
if (return_counts) {
auto* count = context.Output<framework::Tensor>("Counts");
count->Resize(phi::make_ddim({out->numel()}));
auto* counts_data = count->mutable_data<IndexT>(context.GetPlace());
auto* counts_data = context.template Alloc<IndexT>(count);
std::copy(counts_vec.begin(), counts_vec.end(), counts_data);
}
}
template <class ForwardIt, typename InT, typename IndexT>
// Visitor used for index-type dispatch on the flattened (no axis) path.
// It captures every argument of UniqueConsecutiveFlattenedTensor at
// construction time; apply<IndexT>() then forwards them with the selected
// index type.
template <typename Context, typename InT>
struct UniqueConsecutiveFlattenedTensorFunctor {
  const Context& ctx_;         // device context used for allocation
  const DenseTensor& in_;      // input tensor
  DenseTensor* out_;           // output tensor
  const bool return_inverse_;  // whether `inverse_` should be produced
  const bool return_counts_;   // whether `count_` should be produced
  DenseTensor* inverse_;       // inverse-index output
  DenseTensor* count_;         // counts output

  UniqueConsecutiveFlattenedTensorFunctor(const Context& dev_ctx,
                                          const DenseTensor& input,
                                          DenseTensor* output,
                                          bool want_inverse,
                                          bool want_counts,
                                          DenseTensor* inverse_out,
                                          DenseTensor* count_out)
      : ctx_(dev_ctx),
        in_(input),
        out_(output),
        return_inverse_(want_inverse),
        return_counts_(want_counts),
        inverse_(inverse_out),
        count_(count_out) {}

  // Runs the flattened implementation with `IndexT` as the index type.
  template <typename IndexT>
  void apply() const {
    UniqueConsecutiveFlattenedTensor<InT, IndexT, Context>(ctx_,
                                                           in_,
                                                           out_,
                                                           return_inverse_,
                                                           return_counts_,
                                                           inverse_,
                                                           count_);
  }
};
template <typename Context, class ForwardIt, typename InT, typename IndexT>
static ForwardIt UniqueConsecutiveDimImpl(
const framework::ExecutionContext& context,
const Context& context,
ForwardIt first,
ForwardIt last,
const std::vector<IndexT>& sorted_indices_vec,
......@@ -104,7 +129,7 @@ static ForwardIt UniqueConsecutiveDimImpl(
while (++first != last) {
int64_t idx_first = std::distance(begin, first);
int64_t idx_result = std::distance(begin, result);
if (!Equal<InT>(*result, *first)) {
if (!phi::funcs::Equal<InT>(*result, *first)) {
if (++result != first) {
*result = std::move(*first);
}
......@@ -116,13 +141,15 @@ static ForwardIt UniqueConsecutiveDimImpl(
return ++result;
}
template <typename DeviceContext, typename InT, typename IndexT>
static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out,
template <typename Context, typename InT, typename IndexT>
static void UniqueConsecutiveDim(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
int axis) {
int axis,
DenseTensor* inverse,
DenseTensor* count) {
// transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
std::vector<int> permute(in.dims().size());
std::iota(permute.begin(), permute.end(), 0);
......@@ -131,15 +158,14 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
in_trans_dims_vec[axis] = in.dims()[0];
in_trans_dims_vec[0] = in.dims()[axis];
framework::Tensor in_trans;
framework::DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec);
DenseTensor in_trans;
DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims);
in_trans.mutable_data<InT>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
TransCompute<DeviceContext, InT>(
in.dims().size(), dev_ctx, in, &in_trans, permute);
context.template Alloc<InT>(&in_trans);
phi::funcs::TransCompute<Context, InT>(
in.dims().size(), context, in, &in_trans, permute);
// reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims);
std::vector<IndexT> sorted_indices_vec(in_trans.dims()[0]);
......@@ -148,140 +174,88 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
const InT* in_trans_data = in_trans.data<InT>();
// sort tensor according to indices
framework::Tensor input_sorted;
DenseTensor input_sorted;
input_sorted.Resize(in_trans_dims);
input_sorted.mutable_data<InT>(context.GetPlace());
context.template Alloc<InT>(&input_sorted);
InT* input_sorted_data = input_sorted.data<InT>();
for (size_t i = 0; i < sorted_indices_vec.size(); ++i) {
memcpy(input_sorted_data + i * col,
in_trans_data + static_cast<int64_t>(sorted_indices_vec[i]) * col,
col * sizeof(InT));
}
std::vector<framework::Tensor> input_unbind = Unbind(input_sorted);
std::vector<DenseTensor> input_unbind = phi::funcs::Unbind(input_sorted);
std::vector<IndexT> inverse_vec(sorted_indices_vec.size(), 0);
std::vector<IndexT> counts_vec(sorted_indices_vec.size(), 0);
auto last =
UniqueConsecutiveDimImpl<std::vector<framework::Tensor>::iterator, InT>(
context,
input_unbind.begin(),
input_unbind.end(),
sorted_indices_vec,
&inverse_vec,
&counts_vec);
auto last = UniqueConsecutiveDimImpl<Context,
std::vector<DenseTensor>::iterator,
InT>(context,
input_unbind.begin(),
input_unbind.end(),
sorted_indices_vec,
&inverse_vec,
&counts_vec);
input_unbind.erase(last, input_unbind.end());
counts_vec.erase(counts_vec.begin() + input_unbind.size(), counts_vec.end());
math::ConcatFunctor<DeviceContext, InT> concat_functor;
framework::Tensor out_trans;
phi::funcs::ConcatFunctor<Context, InT> concat_functor;
DenseTensor out_trans;
std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
out_trans_dims_vec[0] = input_unbind.size();
out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
out_trans.mutable_data<InT>(context.GetPlace());
context.template Alloc<InT>(&out_trans);
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(phi::make_ddim(out_trans_dims_vec));
out->mutable_data<InT>(context.GetPlace());
concat_functor(dev_ctx, input_unbind, 0, &out_trans);
TransCompute<DeviceContext, InT>(
out_trans.dims().size(), dev_ctx, out_trans, out, permute);
context.template Alloc<InT>(out);
concat_functor(context, input_unbind, 0, &out_trans);
phi::funcs::TransCompute<Context, InT>(
out_trans.dims().size(), context, out_trans, out, permute);
if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index");
framework::TensorFromVector(inverse_vec, context.device_context(), inverse);
paddle::framework::TensorFromVector(inverse_vec, context, inverse);
}
if (return_counts) {
auto* count = context.Output<framework::Tensor>("Counts");
framework::TensorFromVector(counts_vec, context.device_context(), count);
paddle::framework::TensorFromVector(counts_vec, context, count);
}
}
template <typename DeviceContext, typename InT>
struct UniqueConsecutiveFlattendTensorFunctor {
const framework::ExecutionContext& ctx_;
const framework::Tensor& in_;
framework::Tensor* out_;
const bool return_inverse_;
const bool return_counts_;
UniqueConsecutiveFlattendTensorFunctor(
const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out,
bool return_inverse,
bool return_counts)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattendTensor<InT, IndexT>(
ctx_, in_, out_, return_inverse_, return_counts_);
}
};
// Functor that stores the arguments needed by UniqueConsecutiveDim and
// forwards them from apply<IndexT>(), so the index type is chosen by whoever
// calls apply().
// NOTE(review): this span interleaves two revisions of the struct (the
// framework::ExecutionContext / framework::Tensor form and the
// Context / DenseTensor form). Only one of each duplicated line can be kept
// for the definition to compile.
template <typename DeviceContext, typename InT>
template <typename Context, typename InT>
struct UniqueConsecutiveDimFunctor {
const framework::ExecutionContext& ctx_;
const framework::Tensor& in_;
framework::Tensor* out_;
const Context& ctx_;
const DenseTensor& in_;
DenseTensor* out_;
const int axis_;
const bool return_inverse_;
const bool return_counts_;
UniqueConsecutiveDimFunctor(const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out,
DenseTensor* inverse_;
DenseTensor* count_;
UniqueConsecutiveDimFunctor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
const int axis,
bool return_inverse,
bool return_counts)
bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context),
in_(in),
out_(out),
axis_(axis),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
// Forwards the stored arguments to UniqueConsecutiveDim with `IndexT` as the
// index type.
template <typename IndexT>
void apply() const {
UniqueConsecutiveDim<DeviceContext, InT, IndexT>(
ctx_, in_, out_, return_inverse_, return_counts_, axis_);
UniqueConsecutiveDim<Context, InT, IndexT>(ctx_,
in_,
out_,
return_inverse_,
return_counts_,
axis_,
inverse_,
count_);
}
};
// Fluid CPU kernel for the `unique_consecutive` operator.
//
// Reads input "X", output "Out" and the attributes "dtype", "axis",
// "return_inverse" and "return_counts" from the execution context, then
// dispatches on "dtype" (through framework::VisitDataTypeTiny) to either the
// flattened implementation (empty "axis") or the per-axis implementation
// (first entry of "axis").
template <typename DeviceContext, typename T>
class UniqueConsecutiveKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");
    auto data_type = static_cast<framework::proto::VarType::Type>(
        context.Attr<int>("dtype"));
    // With int32 indices every element position must be representable as int.
    if (data_type == framework::proto::VarType::INT32) {
      PADDLE_ENFORCE_LE(
          x->numel(),
          INT_MAX,
          platform::errors::InvalidArgument(
              "The number of elements in Input(X) should be less than or "
              "equal to INT_MAX, but received num is %d. Please set `dtype` to "
              "int64.",
              x->numel()));
    }
    std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
    bool return_inverse = context.Attr<bool>("return_inverse");
    bool return_counts = context.Attr<bool>("return_counts");
    if (axis_vec.empty()) {
      framework::VisitDataTypeTiny(
          data_type,
          UniqueConsecutiveFlattendTensorFunctor<DeviceContext, T>(
              context, *x, out, return_inverse, return_counts));
    } else {
      int axis = axis_vec[0];
      // UniqueConsecutiveDim indexes `in.dims()[axis]` directly, so a
      // negative (count-from-the-end) axis has to be converted to its
      // non-negative equivalent first.
      if (axis < 0) {
        axis += x->dims().size();
      }
      framework::VisitDataTypeTiny(
          data_type,
          UniqueConsecutiveDimFunctor<DeviceContext, T>(
              context, *x, out, axis, return_inverse, return_counts));
    }
  }
};
} // namespace operators
} // namespace paddle
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unique_consecutive_kernel.h"
#include "paddle/phi/kernels/cpu/unique_consecutive_functor.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/data_type.h"
namespace phi {
template <typename T, typename Context>
void UniqueConsecutiveKernel(const Context& dev_ctx,
const DenseTensor& x,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis,
int dtype,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
if (data_type == paddle::framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x.numel(),
INT_MAX,
phi::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x.numel()));
}
if (axis.empty()) {
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattenedTensorFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else {
int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimFunctor<Context, T>(dev_ctx,
x,
out,
valid_axis,
return_inverse,
return_counts,
index,
counts));
}
}
} // namespace phi
// Registers phi::UniqueConsecutiveKernel under the kernel name
// `unique_consecutive` with the CPU / ALL_LAYOUT arguments, listing float,
// double, int32_t and int64_t as the element types. The trailing `{}` block
// is left empty.
PD_REGISTER_KERNEL(unique_consecutive,
CPU,
ALL_LAYOUT,
phi::UniqueConsecutiveKernel,
float,
double,
int32_t,
int64_t) {}
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <thrust/adjacent_difference.h>
#include <thrust/device_vector.h>
#include <thrust/execution_policy.h>
......@@ -22,13 +24,204 @@ limitations under the License. */
#include <iostream>
#include <vector>
#include "paddle/fluid/framework/tensor_util.h" // TensorToVector()
#include "paddle/fluid/operators/unique_consecutive_op.h" // TransComute()
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unique_functor.h"
namespace phi {
// The core logic of computing Unique Consecutive for a flattened Tensor.
//
// Works on a private copy of `in` and drives everything through thrust calls
// issued with the `thrust::device` policy:
//   - `out` receives a copy of the input, is compacted in place with
//     thrust::unique_by_key (predicate `equal`) and is then resized to the
//     number of kept elements;
//   - `inverse` (only when `return_inverse`) is resized to `num_input` and
//     filled from an inclusive scan over the `not_equal` adjacent
//     differences of the input copy;
//   - `counts` (only when `return_counts`) is resized to the number of kept
//     elements and filled from the adjacent differences of the kept
//     positions.
// `num_input` is the element count the caller supplies for `in`.
template <typename Context,
typename InT,
typename IndexT,
typename equal_T,
typename not_equal_T>
static void UniqueConsecutiveFlattenedCUDATensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t num_input,
DenseTensor* inverse,
DenseTensor* counts) {
// 0. Preparation
// `in_hat` is a scratch copy of the input; step 2 reads it after `out` has
// already been compacted in place.
DenseTensor in_hat;
phi::Copy(context, in, context.GetPlace(), false, &in_hat);
auto in_data_hat = context.template Alloc<InT>(&in_hat);
// `sorted_indices` holds `num_input` entries produced by thrust::sequence;
// it is used as the scatter map in step 2.
DenseTensor sorted_indices;
sorted_indices.Resize(phi::make_ddim({num_input}));
auto sorted_indices_data = context.template Alloc<IndexT>(&sorted_indices);
thrust::sequence(
thrust::device, sorted_indices_data, sorted_indices_data + num_input);
// 1. Calculate op result: 'out'
// `range` has one more slot than the input so that step 3 can store an end
// marker at index `num_out`.
DenseTensor range;
range.Resize(phi::make_ddim({num_input + 1}));
auto range_data_ptr = context.template Alloc<IndexT>(&range);
thrust::sequence(
thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
phi::Copy(context, in_hat, context.GetPlace(), false, out);
// NOTE(review): the pointer difference below is stored in an `int`.
int num_out;
auto out_data = context.template Alloc<InT>(out);
// Compacts `out_data` (keys) together with `range_data_ptr` (values); the
// returned key end gives the number of kept elements.
num_out =
thrust::unique_by_key(
thrust::device, out_data, out_data + num_input, range_data_ptr, equal)
.first -
out_data;
out->Resize(phi::make_ddim({num_out}));
// 2. Calculate inverse index: 'inverse'
if (return_inverse) {
inverse->Resize(phi::make_ddim({num_input}));
auto inverse_data = context.template Alloc<IndexT>(inverse);
DenseTensor inv_loc;
inv_loc.Resize(phi::make_ddim({num_input}));
auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
// Applies `not_equal` to neighbouring input elements, writing the results
// into `inv_loc`.
thrust::adjacent_difference(thrust::device,
in_data_hat,
in_data_hat + num_input,
inv_loc_data_ptr,
not_equal);
// NOTE(review): element 0 is written unconditionally; confirm this function
// is never reached with num_input == 0.
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault
// In-place running sum of `inv_loc`.
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
inv_loc_data_ptr);
// Writes the scanned values into `inverse` at the positions given by
// `sorted_indices_data`.
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
sorted_indices_data,
inverse_data);
}
// 3. Calculate 'counts'
if (return_counts) {
counts->Resize(phi::make_ddim({num_out}));
auto count_data = context.template Alloc<IndexT>(counts);
// init 'count_data' as 0
thrust::fill(thrust::device, count_data, count_data + num_out, 0);
// Store `num_input` as the end marker after the kept positions, then take
// adjacent differences of range[1 .. num_out] to fill `count_data`.
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = num_input;
thrust::adjacent_difference(thrust::device,
range_data_ptr + 1,
range_data_ptr + num_out + 1,
count_data);
}
}
// functor for processing a flattend Tensor
template <typename Context, typename InT>
struct UniqueConsecutiveFlattenedCUDAFunctor {
const Context& ctx_;
const DenseTensor& in_;
DenseTensor* out_;
const bool return_inverse_;
const bool return_counts_;
DenseTensor* inverse_;
DenseTensor* count_;
UniqueConsecutiveFlattenedCUDAFunctor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattenedCUDATensor<Context, InT, IndexT>(
ctx_,
in_,
out_,
return_inverse_,
return_counts_,
thrust::equal_to<InT>(),
thrust::not_equal_to<InT>(),
in_.numel(),
inverse_,
count_);
}
};
namespace paddle {
namespace operators {
// The logic of computing unique-consecutive when an axis is given; it differs
// slightly from the flattened function above.
//
// Args:
//   context:             device context used for allocations.
//   sorted_indices:      tensor owning `sorted_indices_data`; resized to the
//                        number of kept rows on return.
//   sorted_indices_data: `row` row indices; compacted in place with
//                        thrust::unique_by_key using `equal`.
//   out:                 unused here.
//   return_inverse:      when true, `inverse` is resized to `row` and filled.
//   return_counts:       when true, `counts` is resized to the number of kept
//                        rows and filled.
//   equal / not_equal:   binary functors applied to row indices.
//   row:                 number of rows being compared.
//   inverse / counts:    output tensors; only touched when the matching flag
//                        is set.
template <typename Context,
          typename InT,
          typename IndexT,
          typename equal_T,
          typename not_equal_T>
static void ComputeUniqueConsecutiveDims(const Context& context,
                                         DenseTensor* sorted_indices,
                                         IndexT* sorted_indices_data,
                                         DenseTensor* out,
                                         bool return_inverse,
                                         bool return_counts,
                                         equal_T equal,
                                         not_equal_T not_equal,
                                         int64_t row,
                                         DenseTensor* inverse,
                                         DenseTensor* counts) {
  // 1. inverse indices: 'inverse'
  // Must run before step 2, which compacts `sorted_indices_data` in place.
  if (return_inverse) {
    inverse->Resize(phi::make_ddim({row}));
    auto inverse_data = context.template Alloc<IndexT>(inverse);
    DenseTensor inv_loc;
    inv_loc.Resize(phi::make_ddim({row}));
    auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
    thrust::adjacent_difference(thrust::device,
                                sorted_indices_data,
                                sorted_indices_data + row,
                                inv_loc_data_ptr,
                                not_equal);
    // Element 0 is set to 0 through a device_ptr before the scan.
    thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
    inv_loc_data_dev[0] = 0;
    thrust::inclusive_scan(thrust::device,
                           inv_loc_data_ptr,
                           inv_loc_data_ptr + row,
                           inv_loc_data_ptr);
    thrust::scatter(thrust::device,
                    inv_loc_data_ptr,
                    inv_loc_data_ptr + row,
                    sorted_indices_data,
                    inverse_data);
  }

  // 2. sorted indices
  // `range` has one extra slot so an end marker can be stored at `num_out`.
  DenseTensor range;
  range.Resize(phi::make_ddim({row + 1}));
  auto range_data_ptr = context.template Alloc<IndexT>(&range);
  thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
  int num_out;
  num_out = thrust::unique_by_key(thrust::device,
                                  sorted_indices_data,
                                  sorted_indices_data + row,
                                  range_data_ptr,
                                  equal)
                .first -
            sorted_indices_data;
  thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
  range_data_ptr_dev[num_out] = row;
  sorted_indices->Resize(phi::make_ddim({num_out}));

  // 3. counts: 'counts'
  if (return_counts) {
    counts->Resize(phi::make_ddim({num_out}));
    auto count_data = context.template Alloc<IndexT>(counts);
    // `counts` holds exactly `num_out` elements, so both writes below must be
    // bounded by `num_out` (not `row`) to stay inside the allocation.
    thrust::fill(thrust::device, count_data, count_data + num_out, 0);
    thrust::adjacent_difference(thrust::device,
                                range_data_ptr + 1,
                                range_data_ptr + num_out + 1,
                                count_data);
  }
}
// Binary function 'equal_to'
template <typename InT>
......@@ -73,11 +266,11 @@ struct BinaryNotEqual {
};
// index_select() function for Tensor
template <typename InT, typename IndexT>
void IndexSelect(const framework::ExecutionContext& context,
const Tensor& input,
const Tensor& index,
Tensor* output,
template <typename Context, typename InT, typename IndexT>
void IndexSelect(const Context& context,
const DenseTensor& input,
const DenseTensor& index,
DenseTensor* output,
int dim) {
auto input_dim = input.dims();
auto input_dim_size = input_dim.size();
......@@ -100,17 +293,15 @@ void IndexSelect(const framework::ExecutionContext& context,
std::vector<InT> input_vec;
std::vector<IndexT> index_vec;
paddle::framework::TensorToVector(
input, context.device_context(), &input_vec);
paddle::framework::TensorToVector(
index, context.device_context(), &index_vec);
paddle::framework::TensorToVector(input, context, &input_vec);
paddle::framework::TensorToVector(index, context, &index_vec);
std::vector<InT> out_vec(output->numel());
for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE(
index_vec[i],
0,
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
......@@ -119,7 +310,7 @@ void IndexSelect(const framework::ExecutionContext& context,
PADDLE_ENFORCE_LT(
index_vec[i],
input_dim[dim],
platform::errors::InvalidArgument(
phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input "
"value.",
......@@ -139,162 +330,21 @@ void IndexSelect(const framework::ExecutionContext& context,
}
}
}
output->mutable_data<InT>(context.GetPlace());
framework::TensorFromVector(out_vec, context.device_context(), output);
context.template Alloc<InT>(output);
paddle::framework::TensorFromVector(out_vec, context, output);
output->Resize(output_dim);
}
// The core logic of computing Unique Consecutive for a flattened Tensor.
//
// Works on a private copy of `in` and drives everything through thrust calls
// issued with the `thrust::device` policy:
//   - `out` receives a copy of the input, is compacted in place with
//     thrust::unique_by_key (predicate `equal`) and is then resized to the
//     number of kept elements;
//   - output "Index" (only when `return_inverse`) is resized to `num_input`
//     and filled from an inclusive scan over the `not_equal` adjacent
//     differences of the input copy;
//   - output "Counts" (only when `return_counts`) is resized to the number of
//     kept elements and filled from the adjacent differences of the kept
//     positions.
// `num_input` is the element count the caller supplies for `in`.
template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
static void UniqueConsecutiveFlattendCUDATensor(
const framework::ExecutionContext& context,
const Tensor& in,
Tensor* out,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t num_input) {
// 0. Preparation
// `in_hat` is a scratch copy of the input; step 2 reads it after `out` has
// already been compacted in place.
Tensor in_hat;
framework::TensorCopy(in, context.GetPlace(), &in_hat);
auto in_data_hat = in_hat.mutable_data<InT>(context.GetPlace());
// `sorted_indices` holds `num_input` entries produced by thrust::sequence;
// it is used as the scatter map in step 2.
Tensor sorted_indices;
sorted_indices.Resize(phi::make_ddim({num_input}));
auto sorted_indices_data =
sorted_indices.mutable_data<IndexT>(context.GetPlace());
thrust::sequence(
thrust::device, sorted_indices_data, sorted_indices_data + num_input);
// 1. Calculate op result: 'out'
// `range` has one more slot than the input so that step 3 can store an end
// marker at index `num_out`.
Tensor range;
range.Resize(phi::make_ddim({num_input + 1}));
auto range_data_ptr = range.mutable_data<IndexT>(context.GetPlace());
thrust::sequence(
thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
framework::TensorCopy(in_hat, context.GetPlace(), out);
// NOTE(review): the pointer difference below is stored in an `int`.
int num_out;
auto out_data = out->mutable_data<InT>(context.GetPlace());
// Compacts `out_data` (keys) together with `range_data_ptr` (values); the
// returned key end gives the number of kept elements.
num_out =
thrust::unique_by_key(
thrust::device, out_data, out_data + num_input, range_data_ptr, equal)
.first -
out_data;
out->Resize(phi::make_ddim({num_out}));
// 2. Calculate inverse index: 'inverse'
if (return_inverse) {
Tensor* inverse = context.Output<Tensor>("Index");
inverse->Resize(phi::make_ddim({num_input}));
auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
Tensor inv_loc;
inv_loc.Resize(phi::make_ddim({num_input}));
auto inv_loc_data_ptr = inv_loc.mutable_data<IndexT>(context.GetPlace());
// Applies `not_equal` to neighbouring input elements, writing the results
// into `inv_loc`.
thrust::adjacent_difference(thrust::device,
in_data_hat,
in_data_hat + num_input,
inv_loc_data_ptr,
not_equal);
// NOTE(review): element 0 is written unconditionally; confirm this function
// is never reached with num_input == 0.
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault
// In-place running sum of `inv_loc`.
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
inv_loc_data_ptr);
// Writes the scanned values into `inverse` at the positions given by
// `sorted_indices_data`.
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
sorted_indices_data,
inverse_data);
}
// 3. Calculate 'counts'
if (return_counts) {
Tensor* counts = context.Output<Tensor>("Counts");
counts->Resize(phi::make_ddim({num_out}));
auto count_data = counts->mutable_data<IndexT>(context.GetPlace());
// init 'count_data' as 0
thrust::fill(thrust::device, count_data, count_data + num_out, 0);
// Store `num_input` as the end marker after the kept positions, then take
// adjacent differences of range[1 .. num_out] to fill `count_data`.
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = num_input;
thrust::adjacent_difference(thrust::device,
range_data_ptr + 1,
range_data_ptr + num_out + 1,
count_data);
}
}
// The logic of computing unique-consecutive when an axis is given; it differs
// slightly from the flattened function above.
//
// Args:
//   context:             execution context; outputs "Index" and "Counts" are
//                        fetched from it when the matching flag is set.
//   sorted_indices:      tensor owning `sorted_indices_data`; resized to the
//                        number of kept rows on return.
//   sorted_indices_data: `row` row indices; compacted in place with
//                        thrust::unique_by_key using `equal`.
//   out:                 unused here.
//   return_inverse:      when true, "Index" is resized to `row` and filled.
//   return_counts:       when true, "Counts" is resized to the number of kept
//                        rows and filled.
//   equal / not_equal:   binary functors applied to row indices.
//   row:                 number of rows being compared.
template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
static void ComputeUniqueConsecutiveDims(
    const framework::ExecutionContext& context,
    Tensor* sorted_indices,
    IndexT* sorted_indices_data,
    Tensor* out,
    bool return_inverse,
    bool return_counts,
    equal_T equal,
    not_equal_T not_equal,
    int64_t row) {
  // 1. inverse indices: 'inverse'
  // Must run before step 2, which compacts `sorted_indices_data` in place.
  if (return_inverse) {
    Tensor* inverse = context.Output<Tensor>("Index");
    inverse->Resize(phi::make_ddim({row}));
    auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
    Tensor inv_loc;
    inv_loc.Resize(phi::make_ddim({row}));
    auto inv_loc_data_ptr = inv_loc.mutable_data<IndexT>(context.GetPlace());
    thrust::adjacent_difference(thrust::device,
                                sorted_indices_data,
                                sorted_indices_data + row,
                                inv_loc_data_ptr,
                                not_equal);
    // Element 0 is set to 0 through a device_ptr before the scan.
    thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
    inv_loc_data_dev[0] = 0;
    thrust::inclusive_scan(thrust::device,
                           inv_loc_data_ptr,
                           inv_loc_data_ptr + row,
                           inv_loc_data_ptr);
    thrust::scatter(thrust::device,
                    inv_loc_data_ptr,
                    inv_loc_data_ptr + row,
                    sorted_indices_data,
                    inverse_data);
  }

  // 2. sorted indices
  // `range` has one extra slot so an end marker can be stored at `num_out`.
  Tensor range;
  range.Resize(phi::make_ddim({row + 1}));
  auto range_data_ptr = range.mutable_data<IndexT>(context.GetPlace());
  thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
  int num_out;
  num_out = thrust::unique_by_key(thrust::device,
                                  sorted_indices_data,
                                  sorted_indices_data + row,
                                  range_data_ptr,
                                  equal)
                .first -
            sorted_indices_data;
  thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
  range_data_ptr_dev[num_out] = row;
  sorted_indices->Resize(phi::make_ddim({num_out}));

  // 3. counts: 'counts'
  if (return_counts) {
    Tensor* counts = context.Output<Tensor>("Counts");
    counts->Resize(phi::make_ddim({num_out}));
    auto count_data = counts->mutable_data<IndexT>(context.GetPlace());
    // `counts` holds exactly `num_out` elements, so both writes below must be
    // bounded by `num_out` (not `row`) to stay inside the allocation.
    thrust::fill(thrust::device, count_data, count_data + num_out, 0);
    thrust::adjacent_difference(thrust::device,
                                range_data_ptr + 1,
                                range_data_ptr + num_out + 1,
                                count_data);
  }
}
// Calculate unique consecutive when 'axis' is set
template <typename DeviceContext, typename InT, typename IndexT>
static void UniqueConsecutiveDimsCUDATensor(
const framework::ExecutionContext& context,
const Tensor& in,
Tensor* out,
bool return_inverse,
bool return_counts,
int axis) {
template <typename Context, typename InT, typename IndexT>
static void UniqueConsecutiveDimsCUDATensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
int axis,
DenseTensor* inverse,
DenseTensor* counts) {
// 1. Transpose & reshape
// Transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
std::vector<int> permute(in.dims().size());
......@@ -304,19 +354,18 @@ static void UniqueConsecutiveDimsCUDATensor(
std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
in_trans_dims_vec[axis] = in.dims()[0];
in_trans_dims_vec[0] = in.dims()[axis];
framework::Tensor in_trans;
framework::DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec);
DenseTensor in_trans;
DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims);
in_trans.mutable_data<InT>(context.GetPlace());
auto& dev_ctx = context.cuda_device_context();
TransCompute<DeviceContext, InT>(in.dims().size(), // num of dims
dev_ctx, // device
in, // original Tensor
&in_trans, // Tensor after reshape
permute); // index of axis
context.template Alloc<InT>(&in_trans);
phi::funcs::TransCompute<Context, InT>(in.dims().size(), // num of dims
context, // device
in, // original Tensor
&in_trans, // Tensor after reshape
permute); // index of axis
// Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims);
// now 'in_trans' is 2D
......@@ -324,16 +373,15 @@ static void UniqueConsecutiveDimsCUDATensor(
int64_t row = in_trans.dims()[0];
const InT* in_trans_data = in_trans.data<InT>();
Tensor sorted_indices;
DenseTensor sorted_indices;
sorted_indices.Resize(phi::make_ddim({row}));
auto sorted_indices_data =
sorted_indices.mutable_data<IndexT>(context.GetPlace());
auto sorted_indices_data = context.template Alloc<IndexT>(&sorted_indices);
// 2. Calculate 'inverse', 'counts'
// Init index
thrust::sequence(
thrust::device, sorted_indices_data, sorted_indices_data + row);
ComputeUniqueConsecutiveDims<InT, IndexT>(
ComputeUniqueConsecutiveDims<Context, InT, IndexT>(
context,
&sorted_indices,
sorted_indices_data,
......@@ -342,143 +390,70 @@ static void UniqueConsecutiveDimsCUDATensor(
return_counts,
BinaryEqual<InT>(col, in_trans_data),
BinaryNotEqual<InT>(col, in_trans_data),
row);
row,
inverse,
counts);
// 3. Select indices and reshape back to get 'out'
Tensor out_trans;
DenseTensor out_trans;
std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
out_trans_dims_vec[0] = sorted_indices.numel();
out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
out_trans.mutable_data<InT>(context.GetPlace());
context.template Alloc<InT>(&out_trans);
IndexSelect<InT, IndexT>(context, in_trans, sorted_indices, &out_trans, 0);
IndexSelect<Context, InT, IndexT>(
context, in_trans, sorted_indices, &out_trans, 0);
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(phi::make_ddim(out_trans_dims_vec));
out->mutable_data<InT>(context.GetPlace());
std::vector<framework::Tensor> out_trans_unbind = Unbind(out_trans);
math::ConcatFunctor<DeviceContext, InT> concat_functor;
concat_functor(dev_ctx, out_trans_unbind, 0, &out_trans);
TransCompute<DeviceContext, InT>(
out_trans.dims().size(), dev_ctx, out_trans, out, permute);
context.template Alloc<InT>(out);
std::vector<DenseTensor> out_trans_unbind = phi::funcs::Unbind(out_trans);
phi::funcs::ConcatFunctor<Context, InT> concat_functor;
concat_functor(context, out_trans_unbind, 0, &out_trans);
phi::funcs::TransCompute<Context, InT>(
out_trans.dims().size(), context, out_trans, out, permute);
}
// functor for processing a flattend Tensor
template <typename DeviceContext, typename InT>
struct UniqueConsecutiveFlattendCUDAFunctor {
const framework::ExecutionContext& ctx_;
const Tensor& in_;
Tensor* out_;
const bool return_inverse_;
const bool return_counts_;
UniqueConsecutiveFlattendCUDAFunctor(
const framework::ExecutionContext& context,
const Tensor& in,
Tensor* out,
bool return_inverse,
bool return_counts)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattendCUDATensor<InT, IndexT>(
ctx_,
in_,
out_,
return_inverse_,
return_counts_,
thrust::equal_to<InT>(),
thrust::not_equal_to<InT>(),
in_.numel());
}
};
// Functor for processing a multi-dimensional Tensor along a given axis.
//
// It stores the context, the input/output tensors, the axis and the
// return_inverse / return_counts flags (the migrated version additionally
// stores the inverse and count output tensors). apply<IndexT>() forwards the
// stored values to UniqueConsecutiveDimsCUDATensor.
template <typename DeviceContext, typename InT>
template <typename Context, typename InT>
struct UniqueConsecutiveDimsCUDAFunctor {
const framework::ExecutionContext& ctx_;
const Tensor& in_;
Tensor* out_;
const Context& ctx_;
const DenseTensor& in_;
DenseTensor* out_;
const int axis_;
const bool return_inverse_;
const bool return_counts_;
DenseTensor* inverse_;
DenseTensor* count_;
UniqueConsecutiveDimsCUDAFunctor(const framework::ExecutionContext& context,
const Tensor& in,
Tensor* out,
UniqueConsecutiveDimsCUDAFunctor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
const int axis,
bool return_inverse,
bool return_counts)
bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context),
in_(in),
out_(out),
axis_(axis),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
// Invoked with the index type chosen by the caller's data-type dispatch.
template <typename IndexT>
void apply() const {
UniqueConsecutiveDimsCUDATensor<DeviceContext, InT, IndexT>(
ctx_, in_, out_, return_inverse_, return_counts_, axis_);
UniqueConsecutiveDimsCUDATensor<Context, InT, IndexT>(ctx_,
in_,
out_,
return_inverse_,
return_counts_,
axis_,
inverse_,
count_);
}
};
// Unique_Consecutive_op CUDA implementation.
//
// Reads input "X", output "Out" and the attributes "dtype", "axis",
// "return_inverse" and "return_counts" from the execution context, then
// dispatches on the index dtype to either the axis-wise functor (when "axis"
// is non-empty, using its first entry) or the flattened functor.
template <typename InT>
class UniqueConsecutiveKernel<platform::CUDADeviceContext, InT>
    : public framework::OpKernel<InT> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    using VarType = framework::proto::VarType;
    using CUDACtx = platform::CUDADeviceContext;

    auto* x = context.Input<framework::Tensor>("X");
    auto* out = context.Output<framework::Tensor>("Out");
    const auto data_type =
        static_cast<VarType::Type>(context.Attr<int>("dtype"));

    // With int32 indices the element count (plus one) must fit into an int.
    if (data_type == VarType::INT32) {
      PADDLE_ENFORCE_LE(
          x->numel() + 1,
          INT_MAX,
          platform::errors::InvalidArgument(
              "The number of elements in Input(X) should be less than or "
              "equal to INT_MAX, but received num is %d. Please set `dtype` to "
              "int64.",
              x->numel()));
    }

    const std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
    const bool return_inverse = context.Attr<bool>("return_inverse");
    const bool return_counts = context.Attr<bool>("return_counts");

    if (!axis_vec.empty()) {
      // 'axis' is required: only its first entry is used.
      const int axis = axis_vec[0];
      framework::VisitDataTypeTiny(
          data_type,
          UniqueConsecutiveDimsCUDAFunctor<CUDACtx, InT>(
              context, *x, out, axis, return_inverse, return_counts));
      return;
    }

    // 'axis' is not required: flatten the Tensor.
    framework::VisitDataTypeTiny(
        data_type,
        UniqueConsecutiveFlattendCUDAFunctor<CUDACtx, InT>(
            context, *x, out, return_inverse, return_counts));
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
unique_consecutive,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, float>,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, double>,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, int64_t>);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/unique_consecutive_kernel.h"
#include "paddle/phi/kernels/gpu/unique_consecutive_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/data_type.h"
namespace phi {
template <typename T, typename Context>
void UniqueConsecutiveKernel(const Context& dev_ctx,
const DenseTensor& x,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis,
int dtype,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
if (data_type == paddle::framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x.numel() + 1,
INT_MAX,
phi::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x.numel()));
}
// if 'axis' is not required, flatten the Tensor.
if (axis.empty()) {
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattenedCUDAFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else {
// 'axis' is required.
int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx,
x,
out,
valid_axis,
return_inverse,
return_counts,
index,
counts));
}
}
} // namespace phi
// Register phi::UniqueConsecutiveKernel as the GPU kernel named
// `unique_consecutive` (ALL_LAYOUT) for float, double, int32_t and int64_t
// element types. The trailing `{}` is the (empty) body required by the macro.
PD_REGISTER_KERNEL(unique_consecutive,
GPU,
ALL_LAYOUT,
phi::UniqueConsecutiveKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
/**
 * @brief Declaration of the `unique_consecutive` kernel.
 *
 * @tparam T        Element type of the input tensor.
 * @tparam Context  Device context type.
 *
 * @param dev_ctx         Device context.
 * @param x               Input tensor.
 * @param return_inverse  Flag controlling the `index` output.
 * @param return_counts   Flag controlling the `counts` output.
 * @param axis            Axis specification; see the device-specific
 *                        implementation for how it is interpreted.
 * @param dtype           Integer code of the data type used for the index
 *                        outputs -- presumably a VarType::Type value; confirm
 *                        against the implementation.
 * @param out             Output tensor.
 * @param index           Output tensor for the inverse indices.
 * @param counts          Output tensor for the counts.
 */
template <typename T, typename Context>
void UniqueConsecutiveKernel(const Context& dev_ctx,
const DenseTensor& x,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis,
int dtype,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
// Builds the kernel signature for `unique_consecutive`: input "X";
// attributes "return_inverse", "return_counts", "axis", "dtype"; outputs
// "Out", "Index", "Counts". The mapping is fixed, so `ctx` is not consulted.
KernelSignature UniqueConsecutiveOpArgumentMapping(
    const ArgumentMappingContext& ctx) {
  KernelSignature signature(
      "unique_consecutive",
      {"X"},
      {"return_inverse", "return_counts", "axis", "dtype"},
      {"Out", "Index", "Counts"});
  return signature;
}
} // namespace phi
// Register the argument-mapping function above for the `unique_consecutive`
// op.
PD_REGISTER_ARG_MAPPING_FN(unique_consecutive,
phi::UniqueConsecutiveOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册