From f9f3bc21692a359a40bde1368667e9da5b928f67 Mon Sep 17 00:00:00 2001
From: ashburnlee <1578034415@qq.com>
Date: Thu, 24 Sep 2020 04:35:58 +0000
Subject: [PATCH] unique op for cuda is added

---
 paddle/fluid/operators/unique_op.cu | 372 ++++++++++++++++++++++++++++
 1 file changed, 372 insertions(+)
 create mode 100644 paddle/fluid/operators/unique_op.cu

diff --git a/paddle/fluid/operators/unique_op.cu b/paddle/fluid/operators/unique_op.cu
new file mode 100644
index 00000000000..522e2f759b7
--- /dev/null
+++ b/paddle/fluid/operators/unique_op.cu
@@ -0,0 +1,372 @@
+/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+#include <thrust/adjacent_difference.h>
+#include <thrust/copy.h>
+#include <thrust/device_vector.h>
+#include <thrust/execution_policy.h>
+#include <thrust/fill.h>
+#include <thrust/functional.h>
+#include <thrust/scan.h>
+#include <thrust/scatter.h>
+#include <thrust/sequence.h>
+#include <thrust/sort.h>
+#include <thrust/unique.h>
+#include <climits>
+#include <numeric>
+#include <vector>
+#include "paddle/fluid/operators/unique_op.h"  // TransCompute
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+// Binary function 'less than'
+template <typename InT>
+struct LessThan {
+  int col;
+  const InT* in_trans_data;
+
+  LessThan(int64_t _col, const InT* _in_trans_data)
+      : col(_col), in_trans_data(_in_trans_data) {}
+
+  // Lexicographic comparison of the rows indexed by 'a' and 'b'
+  __device__ bool operator()(int64_t a, int64_t b) const {
+    for (int i = 0; i < col; ++i) {
+      InT lhs = in_trans_data[i + a * col];
+      InT rhs = in_trans_data[i + b * col];
+      if (lhs < rhs) {
+        return true;
+      } else if (lhs > rhs) {
+        return false;
+      }
+    }
+    return false;
+  }
+};
+
+// Binary function 'equal_to'
+template <typename InT>
+struct BinaryEqual {
+  int64_t col;
+  const InT* in_trans_data;
+
+  BinaryEqual(int64_t _col, const InT* _in_trans_data)
+      : col(_col), in_trans_data(_in_trans_data) {}
+
+  __device__ bool operator()(int64_t a, int64_t b) const {
+    for (int64_t i = 0; i < col; ++i) {
+      InT lhs = in_trans_data[i + a * col];
+      InT rhs = in_trans_data[i + b * col];
+      if (lhs != rhs) {
+        return false;
+      }
+    }
+    return true;
+  }
+};
+
+// Binary function 'not_equal_to'
+template <typename InT>
+struct BinaryNotEqual {
+  int64_t col;
+  const InT* in_trans_data;
+
+  BinaryNotEqual(int64_t _col, const InT* _in_trans_data)
+      : col(_col), in_trans_data(_in_trans_data) {}
+
+  __device__ int64_t operator()(int64_t a, int64_t b) const {
+    for (int64_t i = 0; i < col; ++i) {
+      InT lhs = in_trans_data[i + a * col];
+      InT rhs = in_trans_data[i + b * col];
+      if (lhs != rhs) {
+        return 1;
+      }
+    }
+    return 0;
+  }
+};
+
+/// The core logic of computing Unique for a flattened tensor
+template <typename InT, typename equal_T, typename not_equal_T>
+static void ComputeUniqueFlatten(const framework::ExecutionContext& context,
+                                 const framework::Tensor& in,
+                                 framework::Tensor* out, bool return_index,
+                                 bool return_inverse, bool return_counts,
+                                 equal_T equal, not_equal_T not_equal,
+                                 int64_t num_input) {
+  // 0. Preparation
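+  // Overall idea (a small worked example): sort a copy of the input while
+  // carrying the original positions along, then deduplicate the sorted
+  // copy. E.g. in = [2, 0, 2, 1] sorts to in_hat = [0, 1, 2, 2] with
+  // sorted_indices = [1, 3, 0, 2], and unique_by_key keeps out = [0, 1, 2].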
+  Tensor in_hat;
+  framework::TensorCopy(in, context.GetPlace(), &in_hat);
+  auto in_data_hat = in_hat.mutable_data<InT>(context.GetPlace());
+
+  Tensor* sorted_indices = context.Output<Tensor>("Indices");
+  sorted_indices->Resize(framework::make_ddim({num_input}));
+  auto sorted_indices_data =
+      sorted_indices->mutable_data<int64_t>(context.GetPlace());
+  thrust::sequence(thrust::device, sorted_indices_data,
+                   sorted_indices_data + num_input);
+  thrust::sort_by_key(thrust::device, in_data_hat, in_data_hat + num_input,
+                      sorted_indices_data);
+
+  // 1. Calculate op result: 'out'
+  Tensor range;
+  range.Resize(framework::make_ddim({num_input + 1}));
+  auto range_data_ptr = range.mutable_data<int64_t>(context.GetPlace());
+  thrust::sequence(thrust::device, range_data_ptr,
+                   range_data_ptr + num_input + 1);
+  framework::TensorCopy(in_hat, context.GetPlace(), out);
+  int num_out;
+  auto out_data = out->mutable_data<InT>(context.GetPlace());
+  num_out = thrust::unique_by_key(thrust::device, out_data,
+                                  out_data + num_input, range_data_ptr, equal)
+                .first -
+            out_data;
+  out->Resize(framework::make_ddim({num_out}));
+
+  // 2. Calculate inverse index: 'inverse'
+  // (done before 'sorted_indices' below, which permutes sorted_indices_data)
+  if (return_inverse) {
+    Tensor* inverse = context.Output<Tensor>("Index");
+    inverse->Resize(framework::make_ddim({num_input}));
+    auto inverse_data = inverse->mutable_data<int64_t>(context.GetPlace());
+    Tensor inv_loc;
+    inv_loc.Resize(framework::make_ddim({num_input}));
+    auto inv_loc_data_ptr = inv_loc.mutable_data<int64_t>(context.GetPlace());
+    thrust::adjacent_difference(thrust::device, in_data_hat,
+                                in_data_hat + num_input, inv_loc_data_ptr,
+                                not_equal);
+    thrust::device_ptr<int64_t> inv_loc_data_dev(inv_loc_data_ptr);
+    inv_loc_data_dev[0] = 0;  // without device_ptr, segmentation fault
+    thrust::inclusive_scan(thrust::device, inv_loc_data_ptr,
+                           inv_loc_data_ptr + num_input, inv_loc_data_ptr);
+    thrust::scatter(thrust::device, inv_loc_data_ptr,
+                    inv_loc_data_ptr + num_input, sorted_indices_data,
+                    inverse_data);
+  }
+
+  // 3. Calculate sorted index: 'sorted_indices'
+  if (return_index) {
+    Tensor indices;
+    indices.Resize(framework::make_ddim({num_input}));
+    auto indices_data_ptr = indices.mutable_data<InT>(context.GetPlace());
+    thrust::copy(thrust::device, in_data_hat, in_data_hat + num_input,
+                 indices_data_ptr);
+    thrust::unique_by_key(thrust::device, indices_data_ptr,
+                          indices_data_ptr + num_input, sorted_indices_data,
+                          equal);
+    sorted_indices->Resize(framework::make_ddim({num_out}));
+  }
+
+  // 4. Calculate 'counts'
+  if (return_counts) {
+    Tensor* counts = context.Output<Tensor>("Counts");
+    counts->Resize(framework::make_ddim({num_out}));
+    auto count_data = counts->mutable_data<int64_t>(context.GetPlace());
+    // init 'count_data' as 0
+    thrust::fill(thrust::device, count_data, count_data + num_out, 0);
+    thrust::device_ptr<int64_t> range_data_ptr_dev(range_data_ptr);
+    range_data_ptr_dev[num_out] = num_input;
+    thrust::adjacent_difference(thrust::device, range_data_ptr + 1,
+                                range_data_ptr + num_out + 1, count_data);
+  }
+}
+
+// The logic of computing Unique when 'axis' is specified; it differs
+// slightly from the flattened case above
+template <typename InT, typename equal_T, typename not_equal_T>
+static void ComputeUniqueDims(const framework::ExecutionContext& context,
+                              framework::Tensor* sorted_indices,
+                              int64_t* sorted_indices_data,
+                              framework::Tensor* out, bool return_index,
+                              bool return_inverse, bool return_counts,
+                              equal_T equal, not_equal_T not_equal,
+                              int64_t row) {
+  // 1. Calculate inverse indices: 'inverse'
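+  // A worked sketch of the inverse computation (hypothetical values): for
+  // sorted data whose rows compare as [x, x, y, z, z], adjacent_difference
+  // with 'not_equal' writes [x, 0, 1, 1, 0]; the first element is then
+  // overwritten with 0, and the inclusive_scan yields [0, 0, 1, 2, 2], i.e.
+  // for every sorted position the index of the unique value it maps to.
+  // Finally, scatter permutes these ranks back to the original element
+  // order using 'sorted_indices_data' as the scatter map.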
+  Tensor* inverse = context.Output<Tensor>("Index");
+  inverse->Resize(framework::make_ddim({row}));  // row == in.shape[0]
+  auto inverse_data = inverse->mutable_data<int64_t>(context.GetPlace());
+  Tensor inv_loc;
+  inv_loc.Resize(framework::make_ddim({row}));
+  auto inv_loc_data_ptr = inv_loc.mutable_data<int64_t>(context.GetPlace());
+  thrust::adjacent_difference(thrust::device, sorted_indices_data,
+                              sorted_indices_data + row, inv_loc_data_ptr,
+                              not_equal);
+  thrust::device_ptr<int64_t> inv_loc_data_dev(inv_loc_data_ptr);
+  inv_loc_data_dev[0] = 0;
+  thrust::inclusive_scan(thrust::device, inv_loc_data_ptr,
+                         inv_loc_data_ptr + row, inv_loc_data_ptr);
+  thrust::scatter(thrust::device, inv_loc_data_ptr, inv_loc_data_ptr + row,
+                  sorted_indices_data, inverse_data);
+
+  // 2. Calculate sorted indices: 'sorted_indices'
+  Tensor range;
+  range.Resize(framework::make_ddim({row + 1}));
+  auto range_data_ptr = range.mutable_data<int64_t>(context.GetPlace());
+  thrust::sequence(thrust::device, range_data_ptr, range_data_ptr + row + 1);
+  int num_out;
+  num_out =
+      thrust::unique_by_key(thrust::device, sorted_indices_data,
+                            sorted_indices_data + row, range_data_ptr, equal)
+          .first -
+      sorted_indices_data;
+  thrust::device_ptr<int64_t> range_data_ptr_dev(range_data_ptr);
+  range_data_ptr_dev[num_out] = row;
+
+  // 3. Calculate 'counts'
+  Tensor* counts = context.Output<Tensor>("Counts");
+  counts->Resize(framework::make_ddim({num_out}));
+  auto count_data = counts->mutable_data<int64_t>(context.GetPlace());
+  thrust::fill(thrust::device, count_data, count_data + num_out, 0);
+  thrust::adjacent_difference(thrust::device, range_data_ptr + 1,
+                              range_data_ptr + num_out + 1, count_data);
+
+  /**
+   * TODO(ashburnlee) implement index_select() to get 'out' and reshape back
+   */
+}
+
+// Calculate unique when 'dim' is not set
+template <typename InT>
+static void UniqueFlattendCUDATensor(
+    const framework::ExecutionContext& context, const framework::Tensor& in,
+    framework::Tensor* out, bool return_index, bool return_inverse,
+    bool return_counts) {
+  ComputeUniqueFlatten<InT>(context, in, out, return_index, return_inverse,
+                            return_counts, thrust::equal_to<InT>(),
+                            thrust::not_equal_to<InT>(), in.numel());
+}
+
+// Calculate unique when 'dim' is set
+template <typename DeviceContext, typename InT>
+static void UniqueDimsCUDATensor(const framework::ExecutionContext& context,
+                                 const framework::Tensor& in,
+                                 framework::Tensor* out, bool return_index,
+                                 bool return_inverse, bool return_counts,
+                                 int axis) {
+  // Transpose & reshape
+  // Transpose tensor: e.g. axis=1, [dim0, dim1, dim2] -> [dim1, dim0, dim2]
+  std::vector<int> permute(in.dims().size());
+  std::iota(permute.begin(), permute.end(), 0);
+  permute[axis] = 0;
+  permute[0] = axis;
+  std::vector<int64_t> in_trans_dims_vec(framework::vectorize(in.dims()));
+  in_trans_dims_vec[axis] = in.dims()[0];
+  in_trans_dims_vec[0] = in.dims()[axis];
+  framework::Tensor in_trans;
+  framework::DDim in_trans_dims = framework::make_ddim(in_trans_dims_vec);
+  in_trans.Resize(in_trans_dims);
+  in_trans.mutable_data<InT>(context.GetPlace());
+  auto& dev_ctx = context.cuda_device_context();
+  TransCompute<DeviceContext, InT>(in.dims().size(),  // number of dimensions
+                                   dev_ctx,           // device context
+                                   in,                // input tensor
+                                   &in_trans,  // transposed tensor (modified)
+                                   permute);   // index permutation of the axes
+
+  // Reshape tensor: e.g. [dim1, dim0, dim2] -> [dim1, dim0*dim2]
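+  // After this reshape each slice along the chosen axis is one row of a 2-D
+  // tensor, so "unique along an axis" reduces to "unique rows". The functors
+  // LessThan / BinaryEqual / BinaryNotEqual above compare two rows element
+  // by element given only their row indices, which lets thrust sort and
+  // deduplicate an index array instead of moving the rows themselves.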
+  framework::DDim in_trans_flat_dims =
+      framework::flatten_to_2d(in_trans_dims, 1);
+  in_trans.Resize(in_trans_flat_dims);
+
+  // Now in_trans is 2-D; the (unsorted) in_trans plays the role of 'in'
+  int64_t col = in_trans.dims()[1];
+  int64_t row = in_trans.dims()[0];
+  const InT* in_trans_data = in_trans.data<InT>();
+
+  // Keep an unsorted copy: it is unbound and concatenated again when
+  // reshaping 'out' back below
+  Tensor in_trans_hat;
+  framework::TensorCopy(in_trans, context.GetPlace(), &in_trans_hat);
+
+  Tensor* sorted_indices = context.Output<Tensor>("Indices");
+  sorted_indices->Resize(framework::make_ddim({row}));
+  auto sorted_indices_data =
+      sorted_indices->mutable_data<int64_t>(context.GetPlace());
+
+  // Init index and sort
+  thrust::sequence(thrust::device, sorted_indices_data,
+                   sorted_indices_data + row);
+  thrust::sort(thrust::device, sorted_indices_data, sorted_indices_data + row,
+               LessThan<InT>(col, in_trans_data));
+
+  ComputeUniqueDims<InT>(context, sorted_indices, sorted_indices_data, out,
+                         return_index, return_inverse, return_counts,
+                         BinaryEqual<InT>(col, in_trans_data),
+                         BinaryNotEqual<InT>(col, in_trans_data), row);
+
+  /**
+   * NOTE: If index_select() is implemented and called in ComputeUniqueDims(),
+   * the code below can be deleted.
+   */
+
+  // Reshape 'out' back
+  std::vector<Tensor> in_trans_unbind = Unbind(in_trans_hat);
+  math::ConcatFunctor<DeviceContext, InT> concat_functor;
+  framework::Tensor out_trans;
+  std::vector<int64_t> out_trans_dims_vec = in_trans_dims_vec;
+  out_trans_dims_vec[0] = in_trans_unbind.size();
+  out_trans.Resize(framework::make_ddim(out_trans_dims_vec));
+  out_trans.mutable_data<InT>(context.GetPlace());
+  std::swap(out_trans_dims_vec[0], out_trans_dims_vec[axis]);
+  out->Resize(framework::make_ddim(out_trans_dims_vec));
+  out->mutable_data<InT>(context.GetPlace());
+
+  concat_functor(dev_ctx, in_trans_unbind, 0, &out_trans);
+  TransCompute<DeviceContext, InT>(out_trans.dims().size(), dev_ctx, out_trans,
+                                   out, permute);
+}
+
+// Unique_op CUDA implementation.
+template <typename InT>
+class UniqueKernel<platform::CUDADeviceContext, InT>
+    : public framework::OpKernel<InT> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* out = context.Output<Tensor>("Out");
+    auto data_type = static_cast<framework::proto::VarType::Type>(
+        context.Attr<int>("dtype"));
+    if (data_type == framework::proto::VarType::INT32) {
+      PADDLE_ENFORCE_LE(
+          x->numel() + 1, INT_MAX,
+          platform::errors::InvalidArgument(
+              "The number of elements in Input(X) should be less than or "
+              "equal to INT_MAX, but received num is %d. Please set `dtype` "
+              "to int64.",
+              x->numel()));
+    }
+
+    if (!context.Attr<bool>("is_sorted")) {
+      auto* index = context.Output<Tensor>("Index");
+      // Legacy behavior (is_sorted == false) is not supported yet
+      // TODO(ashburnlee)
+      return;
+    }
+
+    std::vector<int> axis_vec = context.Attr<std::vector<int>>("axis");
+    bool return_index = context.Attr<bool>("return_index");
+    bool return_inverse = context.Attr<bool>("return_inverse");
+    bool return_counts = context.Attr<bool>("return_counts");
+
+    if (axis_vec.empty()) {
+      UniqueFlattendCUDATensor<InT>(context, *x, out, return_index,
+                                    return_inverse, return_counts);
+    } else {
+      int axis = axis_vec[0];
+      // DeviceContext is explicitly given as CUDADeviceContext here, which
+      // is correct
+      UniqueDimsCUDATensor<platform::CUDADeviceContext, InT>(
+          context, *x, out, return_index, return_inverse, return_counts, axis);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_CUDA_KERNEL(
+    unique, ops::UniqueKernel<paddle::platform::CUDADeviceContext, float>,
+    ops::UniqueKernel<paddle::platform::CUDADeviceContext, double>,
+    ops::UniqueKernel<paddle::platform::CUDADeviceContext, int32_t>,
+    ops::UniqueKernel<paddle::platform::CUDADeviceContext, int64_t>);
--
GitLab