From 98460c009eb6a18339097b8ef9be43a216ce1e5f Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Tue, 12 Jun 2018 09:31:10 -0700 Subject: [PATCH] Simplify the computation in cpu --- paddle/fluid/operators/argsort_op.h | 51 +++++++++++++++-------------- 1 file changed, 27 insertions(+), 24 deletions(-) diff --git a/paddle/fluid/operators/argsort_op.h b/paddle/fluid/operators/argsort_op.h index e13745c4941..7e9112cfb7c 100644 --- a/paddle/fluid/operators/argsort_op.h +++ b/paddle/fluid/operators/argsort_op.h @@ -28,47 +28,50 @@ class ArgsortKernel : public framework::OpKernel { auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); - int axis = static_cast(ctx.Attr("axis")); + int axis = ctx.Attr("axis"); auto in_dims = input->dims(); axis = (axis < 0) ? (in_dims.size() + axis) : axis; const T* in_data = input->data(); T* out_data = output->mutable_data(ctx.GetPlace()); - int64_t* idx_data = indices->mutable_data(ctx.GetPlace()); + int64_t* ids_data = indices->mutable_data(ctx.GetPlace()); - int64_t part_dims_prod = input->numel() / in_dims[axis]; - for (int64_t i = 0; i < part_dims_prod; ++i) { + int64_t groups = input->numel() / in_dims[axis]; + int64_t stride = (axis == in_dims.size() - 1) + ? 1 + : framework::product(framework::slice_ddim( + in_dims, axis + 1, in_dims.size())); + + for (int64_t i = 0; i < groups; ++i) { int64_t idx = i; - std::vector idx_vec(in_dims.size(), 0); + std::vector shape_vec(in_dims.size(), 0); for (int64_t dim = in_dims.size() - 1; dim >= 0; --dim) { if (dim != axis) { - idx_vec[dim] = idx % in_dims[dim]; + shape_vec[dim] = idx % in_dims[dim]; idx /= in_dims[dim]; } } - std::vector> in_vec; - std::vector org_index_vec(in_dims[axis], 0); - for (int64_t j = 0; j < in_dims[axis]; ++j) { - idx_vec[axis] = j; - int64_t index = idx_vec[0]; - for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) { - index = index * in_dims[dim + 1] + idx_vec[dim + 1]; - } - in_vec.push_back(std::pair(in_data[index], j)); - org_index_vec[j] = index; + + int64_t start_index = shape_vec[0]; + for (int64_t dim = 0; dim < in_dims.size() - 1; ++dim) { + start_index = start_index * in_dims[dim + 1] + shape_vec[dim + 1]; + } + + std::vector org_index_vec(in_dims[axis], start_index); + for (int64_t j = 1; j < in_dims[axis]; ++j) { + org_index_vec[j] += j * stride; } - std::sort( - in_vec.begin(), in_vec.end(), - [](const std::pair& v1, const std::pair& v2) { - return v1.first < v2.first; - }); + std::sort(org_index_vec.begin(), org_index_vec.end(), + [in_data](const int64_t v1, const int64_t v2) { + return in_data[v1] < in_data[v2]; + }); for (size_t j = 0; j < org_index_vec.size(); ++j) { - int64_t index = org_index_vec[j]; - out_data[index] = in_vec[j].first; - idx_data[index] = in_vec[j].second; + int64_t index = start_index + j * stride; + out_data[index] = in_data[org_index_vec[j]]; + ids_data[index] = (org_index_vec[j] - start_index) / stride; } } } -- GitLab