未验证 提交 be746adf 编写于 作者: Y Yuang Liu 提交者: GitHub

[operator migration] Migrate kernel of unique consecutive op. (#44228)

上级 f1111f3c
...@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/unique_consecutive_op.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
namespace paddle { namespace paddle {
...@@ -118,11 +117,6 @@ namespace ops = paddle::operators; ...@@ -118,11 +117,6 @@ namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive, REGISTER_OP_WITHOUT_GRADIENT(unique_consecutive,
ops::UniqueConsecutiveOp, ops::UniqueConsecutiveOp,
ops::UniqueConsecutiveOpMaker); ops::UniqueConsecutiveOpMaker);
REGISTER_OP_CPU_KERNEL(unique_consecutive,
ops::UniqueConsecutiveKernel<phi::CPUContext, float>,
ops::UniqueConsecutiveKernel<phi::CPUContext, double>,
ops::UniqueConsecutiveKernel<phi::CPUContext, int32_t>,
ops::UniqueConsecutiveKernel<phi::CPUContext, int64_t>);
REGISTER_OP_VERSION(unique_consecutive) REGISTER_OP_VERSION(unique_consecutive)
.AddCheckpoint( .AddCheckpoint(
R"ROC( R"ROC(
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include <cmath>
#include <numeric>
#include <set>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/concat_and_split.h"
#include "paddle/fluid/operators/transpose_op.h" #include "paddle/phi/core/dense_tensor.h"
#include "paddle/fluid/operators/unique_op.h" #include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unique_functor.h"
namespace phi {
namespace paddle { template <typename InT, typename IndexT, typename Context>
namespace operators { static void UniqueConsecutiveFlattenedTensor(const Context& context,
template <typename InT, typename IndexT> const DenseTensor& in,
static void UniqueConsecutiveFlattendTensor( DenseTensor* out,
const framework::ExecutionContext& context, bool return_inverse,
const framework::Tensor& in, bool return_counts,
framework::Tensor* out, DenseTensor* inverse,
bool return_inverse, DenseTensor* count) {
bool return_counts) {
const InT* in_data = in.data<InT>(); const InT* in_data = in.data<InT>();
std::vector<InT> out_vec(in.numel()); std::vector<InT> out_vec(in.numel());
std::vector<IndexT> inverse_vec(in.numel()); std::vector<IndexT> inverse_vec(in.numel());
...@@ -65,27 +60,57 @@ static void UniqueConsecutiveFlattendTensor( ...@@ -65,27 +60,57 @@ static void UniqueConsecutiveFlattendTensor(
out_vec.resize(output_size); out_vec.resize(output_size);
out->Resize(phi::make_ddim({output_size})); out->Resize(phi::make_ddim({output_size}));
auto* out_data = out->mutable_data<InT>(context.GetPlace()); auto* out_data = context.template Alloc<InT>(out);
std::copy(out_vec.begin(), out_vec.end(), out_data); std::copy(out_vec.begin(), out_vec.end(), out_data);
if (return_inverse) { if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index");
inverse->Resize(phi::make_ddim({in.numel()})); inverse->Resize(phi::make_ddim({in.numel()}));
auto* inverse_data = inverse->mutable_data<IndexT>(context.GetPlace()); auto* inverse_data = context.template Alloc<IndexT>(inverse);
std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data); std::copy(inverse_vec.begin(), inverse_vec.end(), inverse_data);
} }
if (return_counts) { if (return_counts) {
auto* count = context.Output<framework::Tensor>("Counts");
count->Resize(phi::make_ddim({out->numel()})); count->Resize(phi::make_ddim({out->numel()}));
auto* counts_data = count->mutable_data<IndexT>(context.GetPlace()); auto* counts_data = context.template Alloc<IndexT>(count);
std::copy(counts_vec.begin(), counts_vec.end(), counts_data); std::copy(counts_vec.begin(), counts_vec.end(), counts_data);
} }
} }
template <class ForwardIt, typename InT, typename IndexT> template <typename Context, typename InT>
struct UniqueConsecutiveFlattenedTensorFunctor {
const Context& ctx_;
const DenseTensor& in_;
DenseTensor* out_;
const bool return_inverse_;
const bool return_counts_;
DenseTensor* inverse_;
DenseTensor* count_;
UniqueConsecutiveFlattenedTensorFunctor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattenedTensor<InT, IndexT, Context>(
ctx_, in_, out_, return_inverse_, return_counts_, inverse_, count_);
}
};
template <typename Context, class ForwardIt, typename InT, typename IndexT>
static ForwardIt UniqueConsecutiveDimImpl( static ForwardIt UniqueConsecutiveDimImpl(
const framework::ExecutionContext& context, const Context& context,
ForwardIt first, ForwardIt first,
ForwardIt last, ForwardIt last,
const std::vector<IndexT>& sorted_indices_vec, const std::vector<IndexT>& sorted_indices_vec,
...@@ -104,7 +129,7 @@ static ForwardIt UniqueConsecutiveDimImpl( ...@@ -104,7 +129,7 @@ static ForwardIt UniqueConsecutiveDimImpl(
while (++first != last) { while (++first != last) {
int64_t idx_first = std::distance(begin, first); int64_t idx_first = std::distance(begin, first);
int64_t idx_result = std::distance(begin, result); int64_t idx_result = std::distance(begin, result);
if (!Equal<InT>(*result, *first)) { if (!phi::funcs::Equal<InT>(*result, *first)) {
if (++result != first) { if (++result != first) {
*result = std::move(*first); *result = std::move(*first);
} }
...@@ -116,13 +141,15 @@ static ForwardIt UniqueConsecutiveDimImpl( ...@@ -116,13 +141,15 @@ static ForwardIt UniqueConsecutiveDimImpl(
return ++result; return ++result;
} }
template <typename DeviceContext, typename InT, typename IndexT> template <typename Context, typename InT, typename IndexT>
static void UniqueConsecutiveDim(const framework::ExecutionContext& context, static void UniqueConsecutiveDim(const Context& context,
const framework::Tensor& in, const DenseTensor& in,
framework::Tensor* out, DenseTensor* out,
bool return_inverse, bool return_inverse,
bool return_counts, bool return_counts,
int axis) { int axis,
DenseTensor* inverse,
DenseTensor* count) {
// transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2] // transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
std::vector<int> permute(in.dims().size()); std::vector<int> permute(in.dims().size());
std::iota(permute.begin(), permute.end(), 0); std::iota(permute.begin(), permute.end(), 0);
...@@ -131,15 +158,14 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context, ...@@ -131,15 +158,14 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims())); std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
in_trans_dims_vec[axis] = in.dims()[0]; in_trans_dims_vec[axis] = in.dims()[0];
in_trans_dims_vec[0] = in.dims()[axis]; in_trans_dims_vec[0] = in.dims()[axis];
framework::Tensor in_trans; DenseTensor in_trans;
framework::DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec); DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims); in_trans.Resize(in_trans_dims);
in_trans.mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(&in_trans);
auto& dev_ctx = context.template device_context<DeviceContext>(); phi::funcs::TransCompute<Context, InT>(
TransCompute<DeviceContext, InT>( in.dims().size(), context, in, &in_trans, permute);
in.dims().size(), dev_ctx, in, &in_trans, permute);
// reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] // reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims); in_trans.Resize(in_trans_flat_dims);
std::vector<IndexT> sorted_indices_vec(in_trans.dims()[0]); std::vector<IndexT> sorted_indices_vec(in_trans.dims()[0]);
...@@ -148,140 +174,88 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context, ...@@ -148,140 +174,88 @@ static void UniqueConsecutiveDim(const framework::ExecutionContext& context,
const InT* in_trans_data = in_trans.data<InT>(); const InT* in_trans_data = in_trans.data<InT>();
// sort tensor according to indices // sort tensor according to indices
framework::Tensor input_sorted; DenseTensor input_sorted;
input_sorted.Resize(in_trans_dims); input_sorted.Resize(in_trans_dims);
input_sorted.mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(&input_sorted);
InT* input_sorted_data = input_sorted.data<InT>(); InT* input_sorted_data = input_sorted.data<InT>();
for (size_t i = 0; i < sorted_indices_vec.size(); ++i) { for (size_t i = 0; i < sorted_indices_vec.size(); ++i) {
memcpy(input_sorted_data + i * col, memcpy(input_sorted_data + i * col,
in_trans_data + static_cast<int64_t>(sorted_indices_vec[i]) * col, in_trans_data + static_cast<int64_t>(sorted_indices_vec[i]) * col,
col * sizeof(InT)); col * sizeof(InT));
} }
std::vector<framework::Tensor> input_unbind = Unbind(input_sorted); std::vector<DenseTensor> input_unbind = phi::funcs::Unbind(input_sorted);
std::vector<IndexT> inverse_vec(sorted_indices_vec.size(), 0); std::vector<IndexT> inverse_vec(sorted_indices_vec.size(), 0);
std::vector<IndexT> counts_vec(sorted_indices_vec.size(), 0); std::vector<IndexT> counts_vec(sorted_indices_vec.size(), 0);
auto last = auto last = UniqueConsecutiveDimImpl<Context,
UniqueConsecutiveDimImpl<std::vector<framework::Tensor>::iterator, InT>( std::vector<DenseTensor>::iterator,
context, InT>(context,
input_unbind.begin(), input_unbind.begin(),
input_unbind.end(), input_unbind.end(),
sorted_indices_vec, sorted_indices_vec,
&inverse_vec, &inverse_vec,
&counts_vec); &counts_vec);
input_unbind.erase(last, input_unbind.end()); input_unbind.erase(last, input_unbind.end());
counts_vec.erase(counts_vec.begin() + input_unbind.size(), counts_vec.end()); counts_vec.erase(counts_vec.begin() + input_unbind.size(), counts_vec.end());
math::ConcatFunctor<DeviceContext, InT> concat_functor; phi::funcs::ConcatFunctor<Context, InT> concat_functor;
framework::Tensor out_trans; DenseTensor out_trans;
std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec; std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
out_trans_dims_vec[0] = input_unbind.size(); out_trans_dims_vec[0] = input_unbind.size();
out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
out_trans.mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(&out_trans);
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(phi::make_ddim(out_trans_dims_vec)); out->Resize(phi::make_ddim(out_trans_dims_vec));
out->mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(out);
concat_functor(dev_ctx, input_unbind, 0, &out_trans); concat_functor(context, input_unbind, 0, &out_trans);
TransCompute<DeviceContext, InT>( phi::funcs::TransCompute<Context, InT>(
out_trans.dims().size(), dev_ctx, out_trans, out, permute); out_trans.dims().size(), context, out_trans, out, permute);
if (return_inverse) { if (return_inverse) {
auto* inverse = context.Output<framework::Tensor>("Index"); paddle::framework::TensorFromVector(inverse_vec, context, inverse);
framework::TensorFromVector(inverse_vec, context.device_context(), inverse);
} }
if (return_counts) { if (return_counts) {
auto* count = context.Output<framework::Tensor>("Counts"); paddle::framework::TensorFromVector(counts_vec, context, count);
framework::TensorFromVector(counts_vec, context.device_context(), count);
} }
} }
template <typename DeviceContext, typename InT> template <typename Context, typename InT>
struct UniqueConsecutiveFlattendTensorFunctor {
const framework::ExecutionContext& ctx_;
const framework::Tensor& in_;
framework::Tensor* out_;
const bool return_inverse_;
const bool return_counts_;
UniqueConsecutiveFlattendTensorFunctor(
const framework::ExecutionContext& context,
const framework::Tensor& in,
framework::Tensor* out,
bool return_inverse,
bool return_counts)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattendTensor<InT, IndexT>(
ctx_, in_, out_, return_inverse_, return_counts_);
}
};
template <typename DeviceContext, typename InT>
struct UniqueConsecutiveDimFunctor { struct UniqueConsecutiveDimFunctor {
const framework::ExecutionContext& ctx_; const Context& ctx_;
const framework::Tensor& in_; const DenseTensor& in_;
framework::Tensor* out_; DenseTensor* out_;
const int axis_; const int axis_;
const bool return_inverse_; const bool return_inverse_;
const bool return_counts_; const bool return_counts_;
UniqueConsecutiveDimFunctor(const framework::ExecutionContext& context, DenseTensor* inverse_;
const framework::Tensor& in, DenseTensor* count_;
framework::Tensor* out,
UniqueConsecutiveDimFunctor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
const int axis, const int axis,
bool return_inverse, bool return_inverse,
bool return_counts) bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context), : ctx_(context),
in_(in), in_(in),
out_(out), out_(out),
axis_(axis), axis_(axis),
return_inverse_(return_inverse), return_inverse_(return_inverse),
return_counts_(return_counts) {} return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
template <typename IndexT> template <typename IndexT>
void apply() const { void apply() const {
UniqueConsecutiveDim<DeviceContext, InT, IndexT>( UniqueConsecutiveDim<Context, InT, IndexT>(ctx_,
ctx_, in_, out_, return_inverse_, return_counts_, axis_); in_,
out_,
return_inverse_,
return_counts_,
axis_,
inverse_,
count_);
} }
}; };
template <typename DeviceContext, typename T>
class UniqueConsecutiveKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto data_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
if (data_type == framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x->numel(),
INT_MAX,
platform::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x->numel()));
}
std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
bool return_inverse = context.Attr<bool>("return_inverse");
bool return_counts = context.Attr<bool>("return_counts");
if (axis_vec.empty()) { } // namespace phi
framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattendTensorFunctor<DeviceContext, T>(
context, *x, out, return_inverse, return_counts));
} else {
int axis = axis_vec[0];
framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimFunctor<DeviceContext, T>(
context, *x, out, axis, return_inverse, return_counts));
}
}
};
} // namespace operators
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/unique_consecutive_kernel.h"
#include "paddle/phi/kernels/cpu/unique_consecutive_functor.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/data_type.h"
namespace phi {
template <typename T, typename Context>
void UniqueConsecutiveKernel(const Context& dev_ctx,
const DenseTensor& x,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis,
int dtype,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
if (data_type == paddle::framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x.numel(),
INT_MAX,
phi::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x.numel()));
}
if (axis.empty()) {
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattenedTensorFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else {
int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimFunctor<Context, T>(dev_ctx,
x,
out,
valid_axis,
return_inverse,
return_counts,
index,
counts));
}
}
} // namespace phi
PD_REGISTER_KERNEL(unique_consecutive,
CPU,
ALL_LAYOUT,
phi::UniqueConsecutiveKernel,
float,
double,
int32_t,
int64_t) {}
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
Licensed under the Apache License, Version 2.0 (the "License"); // Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. // you may not use this file except in compliance with the License.
You may obtain a copy of the License at // You may obtain a copy of the License at
//
http://www.apache.org/licenses/LICENSE-2.0 // http://www.apache.org/licenses/LICENSE-2.0
//
Unless required by applicable law or agreed to in writing, software // Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS, // distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
limitations under the License. */ // limitations under the License.
#pragma once
#include <thrust/adjacent_difference.h> #include <thrust/adjacent_difference.h>
#include <thrust/device_vector.h> #include <thrust/device_vector.h>
#include <thrust/execution_policy.h> #include <thrust/execution_policy.h>
...@@ -22,13 +24,204 @@ limitations under the License. */ ...@@ -22,13 +24,204 @@ limitations under the License. */
#include <iostream> #include <iostream>
#include <vector> #include <vector>
#include "paddle/fluid/framework/tensor_util.h" // TensorToVector() #include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/unique_consecutive_op.h" // TransComute()
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/concat_and_split_functor.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/unique_functor.h"
namespace phi {
// The core logic of computing Unique Consecutive for a flattend Tensor
template <typename Context,
typename InT,
typename IndexT,
typename equal_T,
typename not_equal_T>
static void UniqueConsecutiveFlattenedCUDATensor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t num_input,
DenseTensor* inverse,
DenseTensor* counts) {
// 0. Preparation
DenseTensor in_hat;
phi::Copy(context, in, context.GetPlace(), false, &in_hat);
auto in_data_hat = context.template Alloc<InT>(&in_hat);
DenseTensor sorted_indices;
sorted_indices.Resize(phi::make_ddim({num_input}));
auto sorted_indices_data = context.template Alloc<IndexT>(&sorted_indices);
thrust::sequence(
thrust::device, sorted_indices_data, sorted_indices_data + num_input);
// 1. Calculate op result: 'out'
DenseTensor range;
range.Resize(phi::make_ddim({num_input + 1}));
auto range_data_ptr = context.template Alloc<IndexT>(&range);
thrust::sequence(
thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
phi::Copy(context, in_hat, context.GetPlace(), false, out);
int num_out;
auto out_data = context.template Alloc<InT>(out);
num_out =
thrust::unique_by_key(
thrust::device, out_data, out_data + num_input, range_data_ptr, equal)
.first -
out_data;
out->Resize(phi::make_ddim({num_out}));
// 2. Calculate inverse index: 'inverse'
if (return_inverse) {
inverse->Resize(phi::make_ddim({num_input}));
auto inverse_data = context.template Alloc<IndexT>(inverse);
DenseTensor inv_loc;
inv_loc.Resize(phi::make_ddim({num_input}));
auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
thrust::adjacent_difference(thrust::device,
in_data_hat,
in_data_hat + num_input,
inv_loc_data_ptr,
not_equal);
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
inv_loc_data_ptr);
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
sorted_indices_data,
inverse_data);
}
// 3. Calculate 'counts'
if (return_counts) {
counts->Resize(phi::make_ddim({num_out}));
auto count_data = context.template Alloc<IndexT>(counts);
// init 'count_data' as 0
thrust::fill(thrust::device, count_data, count_data + num_out, 0);
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = num_input;
thrust::adjacent_difference(thrust::device,
range_data_ptr + 1,
range_data_ptr + num_out + 1,
count_data);
}
}
// functor for processing a flattend Tensor
template <typename Context, typename InT>
struct UniqueConsecutiveFlattenedCUDAFunctor {
const Context& ctx_;
const DenseTensor& in_;
DenseTensor* out_;
const bool return_inverse_;
const bool return_counts_;
DenseTensor* inverse_;
DenseTensor* count_;
UniqueConsecutiveFlattenedCUDAFunctor(const Context& context,
const DenseTensor& in,
DenseTensor* out,
bool return_inverse,
bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattenedCUDATensor<Context, InT, IndexT>(
ctx_,
in_,
out_,
return_inverse_,
return_counts_,
thrust::equal_to<InT>(),
thrust::not_equal_to<InT>(),
in_.numel(),
inverse_,
count_);
}
};
namespace paddle { // The logic of compute unique with axis required, it's a little different
namespace operators { // from above function
template <typename Context,
typename InT,
typename IndexT,
typename equal_T,
typename not_equal_T>
static void ComputeUniqueConsecutiveDims(const Context& context,
DenseTensor* sorted_indices,
IndexT* sorted_indices_data,
DenseTensor* out,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t row,
DenseTensor* inverse,
DenseTensor* counts) {
// 1. inverse indices: 'inverse'
inverse->Resize(phi::make_ddim({row}));
auto inverse_data = context.template Alloc<IndexT>(inverse);
DenseTensor inv_loc;
inv_loc.Resize(phi::make_ddim({row}));
auto inv_loc_data_ptr = context.template Alloc<IndexT>(&inv_loc);
thrust::adjacent_difference(thrust::device,
sorted_indices_data,
sorted_indices_data + row,
inv_loc_data_ptr,
not_equal);
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0;
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + row,
inv_loc_data_ptr);
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + row,
sorted_indices_data,
inverse_data);
// 2. sorted indices
DenseTensor range;
range.Resize(phi::make_ddim({row + 1}));
auto range_data_ptr = context.template Alloc<IndexT>(&range);
thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
int num_out;
num_out = thrust::unique_by_key(thrust::device,
sorted_indices_data,
sorted_indices_data + row,
range_data_ptr,
equal)
.first -
sorted_indices_data;
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = row;
sorted_indices->Resize(phi::make_ddim({num_out}));
using Tensor = framework::Tensor; // 3. counts: 'counts'
counts->Resize(phi::make_ddim({num_out}));
auto count_data = context.template Alloc<IndexT>(counts);
thrust::fill(thrust::device, count_data, count_data + row, 0);
thrust::adjacent_difference(
thrust::device, range_data_ptr + 1, range_data_ptr + row + 1, count_data);
}
// Binary function 'equal_to' // Binary function 'equal_to'
template <typename InT> template <typename InT>
...@@ -73,11 +266,11 @@ struct BinaryNotEqual { ...@@ -73,11 +266,11 @@ struct BinaryNotEqual {
}; };
// index_select() function for Tensor // index_select() function for Tensor
template <typename InT, typename IndexT> template <typename Context, typename InT, typename IndexT>
void IndexSelect(const framework::ExecutionContext& context, void IndexSelect(const Context& context,
const Tensor& input, const DenseTensor& input,
const Tensor& index, const DenseTensor& index,
Tensor* output, DenseTensor* output,
int dim) { int dim) {
auto input_dim = input.dims(); auto input_dim = input.dims();
auto input_dim_size = input_dim.size(); auto input_dim_size = input_dim.size();
...@@ -100,17 +293,15 @@ void IndexSelect(const framework::ExecutionContext& context, ...@@ -100,17 +293,15 @@ void IndexSelect(const framework::ExecutionContext& context,
std::vector<InT> input_vec; std::vector<InT> input_vec;
std::vector<IndexT> index_vec; std::vector<IndexT> index_vec;
paddle::framework::TensorToVector( paddle::framework::TensorToVector(input, context, &input_vec);
input, context.device_context(), &input_vec); paddle::framework::TensorToVector(index, context, &index_vec);
paddle::framework::TensorToVector(
index, context.device_context(), &index_vec);
std::vector<InT> out_vec(output->numel()); std::vector<InT> out_vec(output->numel());
for (int i = 0; i < index_size; i++) { for (int i = 0; i < index_size; i++) {
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
index_vec[i], index_vec[i],
0, 0,
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) " "Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
...@@ -119,7 +310,7 @@ void IndexSelect(const framework::ExecutionContext& context, ...@@ -119,7 +310,7 @@ void IndexSelect(const framework::ExecutionContext& context,
PADDLE_ENFORCE_LT( PADDLE_ENFORCE_LT(
index_vec[i], index_vec[i],
input_dim[dim], input_dim[dim],
platform::errors::InvalidArgument( phi::errors::InvalidArgument(
"Variable value (index) of OP(index_select) " "Variable value (index) of OP(index_select) "
"expected >= 0 and < %ld, but got %ld. Please check input " "expected >= 0 and < %ld, but got %ld. Please check input "
"value.", "value.",
...@@ -139,162 +330,21 @@ void IndexSelect(const framework::ExecutionContext& context, ...@@ -139,162 +330,21 @@ void IndexSelect(const framework::ExecutionContext& context,
} }
} }
} }
output->mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(output);
framework::TensorFromVector(out_vec, context.device_context(), output); paddle::framework::TensorFromVector(out_vec, context, output);
output->Resize(output_dim); output->Resize(output_dim);
} }
// The core logic of computing Unique Consecutive for a flattend Tensor
template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
static void UniqueConsecutiveFlattendCUDATensor(
const framework::ExecutionContext& context,
const Tensor& in,
Tensor* out,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t num_input) {
// 0. Prepration
Tensor in_hat;
framework::TensorCopy(in, context.GetPlace(), &in_hat);
auto in_data_hat = in_hat.mutable_data<InT>(context.GetPlace());
Tensor sorted_indices;
sorted_indices.Resize(phi::make_ddim({num_input}));
auto sorted_indices_data =
sorted_indices.mutable_data<IndexT>(context.GetPlace());
thrust::sequence(
thrust::device, sorted_indices_data, sorted_indices_data + num_input);
// 1. Calculate op result: 'out'
Tensor range;
range.Resize(phi::make_ddim({num_input + 1}));
auto range_data_ptr = range.mutable_data<IndexT>(context.GetPlace());
thrust::sequence(
thrust::device, range_data_ptr, range_data_ptr + num_input + 1);
framework::TensorCopy(in_hat, context.GetPlace(), out);
int num_out;
auto out_data = out->mutable_data<InT>(context.GetPlace());
num_out =
thrust::unique_by_key(
thrust::device, out_data, out_data + num_input, range_data_ptr, equal)
.first -
out_data;
out->Resize(phi::make_ddim({num_out}));
// 2. Calculate inverse index: 'inverse'
if (return_inverse) {
Tensor* inverse = context.Output<Tensor>("Index");
inverse->Resize(phi::make_ddim({num_input}));
auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
Tensor inv_loc;
inv_loc.Resize(phi::make_ddim({num_input}));
auto inv_loc_data_ptr = inv_loc.mutable_data<IndexT>(context.GetPlace());
thrust::adjacent_difference(thrust::device,
in_data_hat,
in_data_hat + num_input,
inv_loc_data_ptr,
not_equal);
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0; // without device_ptr, segmentation fault
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
inv_loc_data_ptr);
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + num_input,
sorted_indices_data,
inverse_data);
}
// 3. Calculate 'counts'
if (return_counts) {
Tensor* counts = context.Output<Tensor>("Counts");
counts->Resize(phi::make_ddim({num_out}));
auto count_data = counts->mutable_data<IndexT>(context.GetPlace());
// init 'count_data' as 0
thrust::fill(thrust::device, count_data, count_data + num_out, 0);
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = num_input;
thrust::adjacent_difference(thrust::device,
range_data_ptr + 1,
range_data_ptr + num_out + 1,
count_data);
}
}
// The logic of compute unique with axis required, it's a little different
// from above function
template <typename InT, typename IndexT, typename equal_T, typename not_equal_T>
static void ComputeUniqueConsecutiveDims(
const framework::ExecutionContext& context,
Tensor* sorted_indices,
IndexT* sorted_indices_data,
Tensor* out,
bool return_inverse,
bool return_counts,
equal_T equal,
not_equal_T not_equal,
int64_t row) {
// 1. inverse indices: 'inverse'
Tensor* inverse = context.Output<Tensor>("Index");
inverse->Resize(phi::make_ddim({row}));
auto inverse_data = inverse->mutable_data<IndexT>(context.GetPlace());
Tensor inv_loc;
inv_loc.Resize(phi::make_ddim({row}));
auto inv_loc_data_ptr = inv_loc.mutable_data<IndexT>(context.GetPlace());
thrust::adjacent_difference(thrust::device,
sorted_indices_data,
sorted_indices_data + row,
inv_loc_data_ptr,
not_equal);
thrust::device_ptr<IndexT> inv_loc_data_dev(inv_loc_data_ptr);
inv_loc_data_dev[0] = 0;
thrust::inclusive_scan(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + row,
inv_loc_data_ptr);
thrust::scatter(thrust::device,
inv_loc_data_ptr,
inv_loc_data_ptr + row,
sorted_indices_data,
inverse_data);
// 2. sorted indices
Tensor range;
range.Resize(phi::make_ddim({row + 1}));
auto range_data_ptr = range.mutable_data<IndexT>(context.GetPlace());
thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
int num_out;
num_out = thrust::unique_by_key(thrust::device,
sorted_indices_data,
sorted_indices_data + row,
range_data_ptr,
equal)
.first -
sorted_indices_data;
thrust::device_ptr<IndexT> range_data_ptr_dev(range_data_ptr);
range_data_ptr_dev[num_out] = row;
sorted_indices->Resize(phi::make_ddim({num_out}));
// 3. counts: 'counts'
Tensor* counts = context.Output<Tensor>("Counts");
counts->Resize(phi::make_ddim({num_out}));
auto count_data = counts->mutable_data<IndexT>(context.GetPlace());
thrust::fill(thrust::device, count_data, count_data + row, 0);
thrust::adjacent_difference(
thrust::device, range_data_ptr + 1, range_data_ptr + row + 1, count_data);
}
// Calculate unique consecutive when 'axis' is set // Calculate unique consecutive when 'axis' is set
template <typename DeviceContext, typename InT, typename IndexT> template <typename Context, typename InT, typename IndexT>
static void UniqueConsecutiveDimsCUDATensor( static void UniqueConsecutiveDimsCUDATensor(const Context& context,
const framework::ExecutionContext& context, const DenseTensor& in,
const Tensor& in, DenseTensor* out,
Tensor* out, bool return_inverse,
bool return_inverse, bool return_counts,
bool return_counts, int axis,
int axis) { DenseTensor* inverse,
DenseTensor* counts) {
// 1. Transpose & reshape // 1. Transpose & reshape
// Transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2] // Transpose tensor: eg. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
std::vector<int> permute(in.dims().size()); std::vector<int> permute(in.dims().size());
...@@ -304,19 +354,18 @@ static void UniqueConsecutiveDimsCUDATensor( ...@@ -304,19 +354,18 @@ static void UniqueConsecutiveDimsCUDATensor(
std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims())); std::vector<int64_t> in_trans_dims_vec(phi::vectorize(in.dims()));
in_trans_dims_vec[axis] = in.dims()[0]; in_trans_dims_vec[axis] = in.dims()[0];
in_trans_dims_vec[0] = in.dims()[axis]; in_trans_dims_vec[0] = in.dims()[axis];
framework::Tensor in_trans; DenseTensor in_trans;
framework::DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec); DDim in_trans_dims = phi::make_ddim(in_trans_dims_vec);
in_trans.Resize(in_trans_dims); in_trans.Resize(in_trans_dims);
in_trans.mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(&in_trans);
auto& dev_ctx = context.cuda_device_context(); phi::funcs::TransCompute<Context, InT>(in.dims().size(), // num of dims
TransCompute<DeviceContext, InT>(in.dims().size(), // num of dims context, // device
dev_ctx, // device in, // original Tensor
in, // original Tensor &in_trans, // Tensor after reshape
&in_trans, // Tensor after reshape permute); // index of axis
permute); // index of axis
// Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2] // Reshape tensor: eg. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
framework::DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1); DDim in_trans_flat_dims = phi::flatten_to_2d(in_trans_dims, 1);
in_trans.Resize(in_trans_flat_dims); in_trans.Resize(in_trans_flat_dims);
// now 'in_trans' is 2D // now 'in_trans' is 2D
...@@ -324,16 +373,15 @@ static void UniqueConsecutiveDimsCUDATensor( ...@@ -324,16 +373,15 @@ static void UniqueConsecutiveDimsCUDATensor(
int64_t row = in_trans.dims()[0]; int64_t row = in_trans.dims()[0];
const InT* in_trans_data = in_trans.data<InT>(); const InT* in_trans_data = in_trans.data<InT>();
Tensor sorted_indices; DenseTensor sorted_indices;
sorted_indices.Resize(phi::make_ddim({row})); sorted_indices.Resize(phi::make_ddim({row}));
auto sorted_indices_data = auto sorted_indices_data = context.template Alloc<IndexT>(&sorted_indices);
sorted_indices.mutable_data<IndexT>(context.GetPlace());
// 2. Calculate 'inverse', 'counts' // 2. Calculate 'inverse', 'counts'
// Init index // Init index
thrust::sequence( thrust::sequence(
thrust::device, sorted_indices_data, sorted_indices_data + row); thrust::device, sorted_indices_data, sorted_indices_data + row);
ComputeUniqueConsecutiveDims<InT, IndexT>( ComputeUniqueConsecutiveDims<Context, InT, IndexT>(
context, context,
&sorted_indices, &sorted_indices,
sorted_indices_data, sorted_indices_data,
...@@ -342,143 +390,70 @@ static void UniqueConsecutiveDimsCUDATensor( ...@@ -342,143 +390,70 @@ static void UniqueConsecutiveDimsCUDATensor(
return_counts, return_counts,
BinaryEqual<InT>(col, in_trans_data), BinaryEqual<InT>(col, in_trans_data),
BinaryNotEqual<InT>(col, in_trans_data), BinaryNotEqual<InT>(col, in_trans_data),
row); row,
inverse,
counts);
// 3. Select indices and reshape back to get 'out' // 3. Select indices and reshape back to get 'out'
Tensor out_trans; DenseTensor out_trans;
std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec; std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
out_trans_dims_vec[0] = sorted_indices.numel(); out_trans_dims_vec[0] = sorted_indices.numel();
out_trans.Resize(phi::make_ddim(out_trans_dims_vec)); out_trans.Resize(phi::make_ddim(out_trans_dims_vec));
out_trans.mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(&out_trans);
IndexSelect<InT, IndexT>(context, in_trans, sorted_indices, &out_trans, 0); IndexSelect<Context, InT, IndexT>(
context, in_trans, sorted_indices, &out_trans, 0);
std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]); std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
out->Resize(phi::make_ddim(out_trans_dims_vec)); out->Resize(phi::make_ddim(out_trans_dims_vec));
out->mutable_data<InT>(context.GetPlace()); context.template Alloc<InT>(out);
std::vector<framework::Tensor> out_trans_unbind = Unbind(out_trans); std::vector<DenseTensor> out_trans_unbind = phi::funcs::Unbind(out_trans);
math::ConcatFunctor<DeviceContext, InT> concat_functor; phi::funcs::ConcatFunctor<Context, InT> concat_functor;
concat_functor(dev_ctx, out_trans_unbind, 0, &out_trans); concat_functor(context, out_trans_unbind, 0, &out_trans);
TransCompute<DeviceContext, InT>( phi::funcs::TransCompute<Context, InT>(
out_trans.dims().size(), dev_ctx, out_trans, out, permute); out_trans.dims().size(), context, out_trans, out, permute);
} }
// functor for processing a flattend Tensor
template <typename DeviceContext, typename InT>
struct UniqueConsecutiveFlattendCUDAFunctor {
const framework::ExecutionContext& ctx_;
const Tensor& in_;
Tensor* out_;
const bool return_inverse_;
const bool return_counts_;
UniqueConsecutiveFlattendCUDAFunctor(
const framework::ExecutionContext& context,
const Tensor& in,
Tensor* out,
bool return_inverse,
bool return_counts)
: ctx_(context),
in_(in),
out_(out),
return_inverse_(return_inverse),
return_counts_(return_counts) {}
template <typename IndexT>
void apply() const {
UniqueConsecutiveFlattendCUDATensor<InT, IndexT>(
ctx_,
in_,
out_,
return_inverse_,
return_counts_,
thrust::equal_to<InT>(),
thrust::not_equal_to<InT>(),
in_.numel());
}
};
// functor for processing a multi-dimentional Tensor // functor for processing a multi-dimentional Tensor
template <typename DeviceContext, typename InT> template <typename Context, typename InT>
struct UniqueConsecutiveDimsCUDAFunctor { struct UniqueConsecutiveDimsCUDAFunctor {
const framework::ExecutionContext& ctx_; const Context& ctx_;
const Tensor& in_; const DenseTensor& in_;
Tensor* out_; DenseTensor* out_;
const int axis_; const int axis_;
const bool return_inverse_; const bool return_inverse_;
const bool return_counts_; const bool return_counts_;
DenseTensor* inverse_;
DenseTensor* count_;
UniqueConsecutiveDimsCUDAFunctor(const framework::ExecutionContext& context, UniqueConsecutiveDimsCUDAFunctor(const Context& context,
const Tensor& in, const DenseTensor& in,
Tensor* out, DenseTensor* out,
const int axis, const int axis,
bool return_inverse, bool return_inverse,
bool return_counts) bool return_counts,
DenseTensor* inverse,
DenseTensor* count)
: ctx_(context), : ctx_(context),
in_(in), in_(in),
out_(out), out_(out),
axis_(axis), axis_(axis),
return_inverse_(return_inverse), return_inverse_(return_inverse),
return_counts_(return_counts) {} return_counts_(return_counts),
inverse_(inverse),
count_(count) {}
template <typename IndexT> template <typename IndexT>
void apply() const { void apply() const {
UniqueConsecutiveDimsCUDATensor<DeviceContext, InT, IndexT>( UniqueConsecutiveDimsCUDATensor<Context, InT, IndexT>(ctx_,
ctx_, in_, out_, return_inverse_, return_counts_, axis_); in_,
out_,
return_inverse_,
return_counts_,
axis_,
inverse_,
count_);
} }
}; };
// Unique_Consecutive_op CUDA implementation. } // namespace phi
template <typename InT>
class UniqueConsecutiveKernel<platform::CUDADeviceContext, InT>
: public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
auto data_type = static_cast<framework::proto::VarType::Type>(
context.Attr<int>("dtype"));
if (data_type == framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x->numel() + 1,
INT_MAX,
platform::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x->numel()));
}
std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
bool return_inverse = context.Attr<bool>("return_inverse");
bool return_counts = context.Attr<bool>("return_counts");
// if 'axis' is not required, flatten the Tensor.
if (axis_vec.empty()) {
framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattendCUDAFunctor<platform::CUDADeviceContext,
InT>(
context, *x, out, return_inverse, return_counts));
} else {
// 'axis' is required.
int axis = axis_vec[0];
framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimsCUDAFunctor<platform::CUDADeviceContext, InT>(
context, *x, out, axis, return_inverse, return_counts));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
unique_consecutive,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, float>,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, double>,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, int32_t>,
ops::UniqueConsecutiveKernel<paddle::platform::CUDADeviceContext, int64_t>);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/kernels/unique_consecutive_kernel.h"
#include "paddle/phi/kernels/gpu/unique_consecutive_functor.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/data_type.h"
namespace phi {
template <typename T, typename Context>
void UniqueConsecutiveKernel(const Context& dev_ctx,
const DenseTensor& x,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis,
int dtype,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
if (data_type == paddle::framework::proto::VarType::INT32) {
PADDLE_ENFORCE_LE(
x.numel() + 1,
INT_MAX,
phi::errors::InvalidArgument(
"The number of elements in Input(X) should be less than or "
"equal to INT_MAX, but received num is %d. Please set `dtype` to "
"int64.",
x.numel()));
}
// if 'axis' is not required, flatten the Tensor.
if (axis.empty()) {
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveFlattenedCUDAFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else {
// 'axis' is required.
int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny(
data_type,
UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx,
x,
out,
valid_axis,
return_inverse,
return_counts,
index,
counts));
}
}
} // namespace phi
PD_REGISTER_KERNEL(unique_consecutive,
GPU,
ALL_LAYOUT,
phi::UniqueConsecutiveKernel,
float,
double,
int32_t,
int64_t) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void UniqueConsecutiveKernel(const Context& dev_ctx,
const DenseTensor& x,
bool return_inverse,
bool return_counts,
const std::vector<int>& axis,
int dtype,
DenseTensor* out,
DenseTensor* index,
DenseTensor* counts);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature UniqueConsecutiveOpArgumentMapping(
const ArgumentMappingContext& ctx) {
return KernelSignature("unique_consecutive",
{"X"},
{"return_inverse", "return_counts", "axis", "dtype"},
{"Out", "Index", "Counts"});
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(unique_consecutive,
phi::UniqueConsecutiveOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册