未验证 提交 538b5721 编写于 作者: J JYChen 提交者: GitHub

[new API] add paddle.kthvalue and paddle.Tensor.kthvalue (#38386)

* add new api/op kthvalue

* kthvalue cuda kernel to cub sorting

* fix example code error

* throw errors instead of LOG in cuda sort

* throw errors by Paddle_ENFORCE
上级 bc827307
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kthvalue_op.h"
#include <memory>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
namespace paddle {
namespace operators {
class KthvalueOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "kthvalue");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "kthvalue");
OP_INOUT_CHECK(ctx->HasOutput("Indices"), "Output", "Indices", "kthvalue");
auto input_dims = ctx->GetInputDim("X");
const int& dim_size = input_dims.size();
int axis = static_cast<int>(ctx->Attrs().Get<int>("axis"));
PADDLE_ENFORCE_LT(axis, dim_size,
paddle::platform::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size, dim_size, axis));
PADDLE_ENFORCE_GE(axis, -dim_size,
paddle::platform::errors::InvalidArgument(
"the axis must be [-%d, %d), but received %d .",
dim_size, dim_size, axis));
if (axis < 0) axis += dim_size;
int k = static_cast<int>(ctx->Attrs().Get<int>("k"));
PADDLE_ENFORCE_GE(
k, 1, paddle::platform::errors::InvalidArgument(
"the k in the kthvalue must >= 1, but received %d .", k));
PADDLE_ENFORCE_GE(input_dims.size(), 1,
paddle::platform::errors::InvalidArgument(
"input of kthvalue must have >= 1d shape"));
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_GE(
input_dims[axis], k,
paddle::platform::errors::InvalidArgument(
"input of kthvalue must have >= %d columns in axis of %d", k,
axis));
}
bool keepdim = ctx->Attrs().Get<bool>("keepdim");
std::vector<int64_t> dimvec;
for (int64_t i = 0; i < axis; i++) {
dimvec.emplace_back(input_dims[i]);
}
if (keepdim) {
dimvec.emplace_back(static_cast<int64_t>(1));
}
for (int64_t i = axis + 1; i < dim_size; i++) {
dimvec.emplace_back(input_dims[i]);
}
framework::DDim dims = framework::make_ddim(dimvec);
ctx->SetOutputDim("Out", dims);
ctx->SetOutputDim("Indices", dims);
ctx->ShareLoD("X", "Out");
ctx->ShareLoD("X", "Indices");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class KthvalueOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddComment(R"DOC(
This operator find the k-th smallest elements in the specific axis of a Tensor.
It will return the values and corresponding indices.
)DOC");
AddInput("X", "(Tensor) The input of Kthvalue op");
AddOutput("Out", "(Tensor) The values of k-th smallest elements of input");
AddOutput("Indices",
"(Tensor) The indices of k-th smallest elements of input");
AddAttr<int>(
"k",
"(int, default 1) k for k-th smallest elements to look for along "
"the tensor).")
.SetDefault(1);
AddAttr<int>("axis",
"the axis to sort and get the k indices, value."
"if not set, will get k-th value in last axis.")
.SetDefault(-1);
AddAttr<bool>("keepdim", "Keep the dim that to reduce.").SetDefault(false);
}
};
class KthvalueOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Indices"), true,
platform::errors::InvalidArgument("Input(Indices) should be not null"));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Grad Input(Out) should be not null"));
PADDLE_ENFORCE_EQ(
ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument("Grad Output(X) should be not null"));
auto x_dims = ctx->GetInputDim("X");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));
return framework::OpKernelType(data_type, ctx.device_context());
}
};
template <typename T>
class KthvalueGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("kthvalue_grad");
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetInput("X", this->Input("X"));
op->SetInput("Indices", this->Output("Indices"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
op->SetAttrMap(this->Attrs());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(kthvalue, ops::KthvalueOp, ops::KthvalueOpMaker,
ops::KthvalueGradOpMaker<paddle::framework::OpDesc>,
ops::KthvalueGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
kthvalue, ops::KthvalueCPUKernel<paddle::platform::CPUPlace, float>,
ops::KthvalueCPUKernel<paddle::platform::CPUPlace, double>,
ops::KthvalueCPUKernel<paddle::platform::CPUPlace, int32_t>,
ops::KthvalueCPUKernel<paddle::platform::CPUPlace, int64_t>);
REGISTER_OPERATOR(kthvalue_grad, ops::KthvalueOpGrad);
REGISTER_OP_CPU_KERNEL(
kthvalue_grad,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, float>,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, double>,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, int32_t>,
ops::KthvalueGradCPUKernel<paddle::platform::CPUPlace, int64_t>);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/kthvalue_op.h"
#include "paddle/fluid/operators/top_k_function_cuda.h"
#include "paddle/fluid/operators/top_k_v2_op.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
#ifdef __HIPCC__
#include <hipcub/hipcub.hpp>
#endif
namespace paddle {
namespace operators {
int getBlockSize(int col) {
if (col > 512)
return 1024;
else if (col > 256 && col <= 512)
return 512;
else if (col > 128 && col <= 256)
return 256;
else if (col > 64 && col <= 128)
return 128;
else
return 64;
}
template <typename T>
bool SortKthvalue(const platform::CUDADeviceContext& ctx,
const framework::Tensor* input_tensor, const int64_t num_cols,
const int64_t num_rows, const int k,
framework::Tensor* out_tensor,
framework::Tensor* indices_tensor) {
auto cu_stream = ctx.stream();
framework::Tensor input_indices;
const std::vector<int64_t> dims = {num_rows, num_cols};
auto dim = framework::make_ddim(dims);
input_indices.Resize(dim);
input_indices.mutable_data<int64_t>(ctx.GetPlace());
size_t temp_storage_bytes = -1;
int block_size = getBlockSize(num_cols);
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize().x;
unsigned int grid_size = num_rows < maxGridDimX
? static_cast<unsigned int>(num_rows)
: maxGridDimX;
InitIndex<int64_t><<<grid_size, block_size, 0, cu_stream>>>(
input_indices.data<int64_t>(), num_rows, num_cols);
cub::CountingInputIterator<int64_t> counting_iter(0);
cub::TransformInputIterator<int64_t, SegmentOffsetIter,
cub::CountingInputIterator<int64_t>>
segment_offsets_t(counting_iter, SegmentOffsetIter(num_cols));
T* sorted_values_ptr;
int64_t* sorted_indices_ptr;
framework::Tensor temp_values, temp_indices;
const T* input = input_tensor->data<T>();
T* values = out_tensor->data<T>();
int64_t* indices = indices_tensor->mutable_data<int64_t>(ctx.GetPlace());
temp_values.Resize(dim);
temp_indices.Resize(dim);
sorted_values_ptr = temp_values.mutable_data<T>(ctx.GetPlace());
sorted_indices_ptr = temp_indices.mutable_data<int64_t>(ctx.GetPlace());
auto err = cub::DeviceSegmentedRadixSort::SortPairs(
nullptr, temp_storage_bytes, input, sorted_values_ptr,
input_indices.data<int64_t>(), sorted_indices_ptr, num_cols * num_rows,
num_rows, segment_offsets_t, segment_offsets_t + 1, 0, sizeof(T) * 8,
cu_stream);
#ifdef __HIPCC__
if (err != hipSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"hipcub::DeviceSegmentedRadixSort::SortPairs, status: "
<< hipGetErrorString(err);
return false;
}
#else
if (err != cudaSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs, status: "
<< cudaGetErrorString(err);
return false;
}
#endif
framework::Tensor temp_storage;
temp_storage.mutable_data<uint8_t>(ctx.GetPlace(), temp_storage_bytes);
err = cub::DeviceSegmentedRadixSort::SortPairs(
temp_storage.data<uint8_t>(), temp_storage_bytes, input,
sorted_values_ptr, input_indices.data<int64_t>(), sorted_indices_ptr,
num_cols * num_rows, num_rows, segment_offsets_t, segment_offsets_t + 1,
0, sizeof(T) * 8, cu_stream);
#ifdef __HIPCC__
if (err != hipSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"hipcub::DeviceSegmentedRadixSort::SortPairs, "
<< temp_storage_bytes << ", status: " << hipGetErrorString(err);
return false;
}
#else
if (err != cudaSuccess) {
LOG(ERROR) << "KthvalueOP failed as could not launch "
"cub::DeviceSegmentedRadixSort::SortPairs, "
<< temp_storage_bytes << ", status: " << cudaGetErrorString(err);
return false;
}
#endif
auto& dev = *ctx.eigen_device();
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_indices{0, k - 1};
const Eigen::DSizes<Eigen::DenseIndex, 2> slice_sizes{num_rows, 1};
auto e_indices = framework::EigenMatrix<int64_t>::From(*indices_tensor, dim);
auto e_tmp_indices = framework::EigenMatrix<int64_t>::From(
static_cast<const framework::Tensor>(temp_indices));
std::vector<int> odims = {static_cast<int>(num_rows), static_cast<int>(1)};
dim = framework::make_ddim(odims);
auto e_values = framework::EigenMatrix<T>::From(*out_tensor, dim);
auto e_tmp_values = framework::EigenMatrix<T>::From(
static_cast<const framework::Tensor>(temp_values));
EigenSlice<std::decay_t<decltype(dev)>, int64_t, 2>::Eval(
dev, e_indices, e_tmp_indices, slice_indices, slice_sizes);
EigenSlice<std::decay_t<decltype(dev)>, T, 2>::Eval(
dev, e_values, e_tmp_values, slice_indices, slice_sizes);
return true;
}
template <typename DeviceContext, typename T>
class KthvalueOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* input = ctx.Input<framework::Tensor>("X");
auto* output = ctx.Output<framework::Tensor>("Out");
auto* indices = ctx.Output<framework::Tensor>("Indices");
int k = static_cast<int>(ctx.Attr<int>("k"));
int axis = static_cast<int>(ctx.Attr<int>("axis"));
bool keepdim = static_cast<bool>(ctx.Attr<bool>("keepdim"));
const auto& in_dims = input->dims();
if (axis < 0) axis += in_dims.size();
auto out_dims = output->dims();
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
if (axis == in_dims.size() - 1) {
const int64_t& input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
const auto& dev_ctx = ctx.cuda_device_context();
PADDLE_ENFORCE_EQ(SortKthvalue<T>(dev_ctx, input, input_width,
input_height, k, output, indices),
true, platform::errors::External(
"KthvalueOP: Error when use cub sorting"));
return;
} else {
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
if (!keepdim) {
std::vector<int> tmp_out_shape;
for (int i = 0; i < axis; i++) {
tmp_out_shape.emplace_back(in_dims[i]);
}
tmp_out_shape.emplace_back(1);
for (int i = axis + 1; i < in_dims.size(); i++) {
tmp_out_shape.emplace_back(in_dims[i]);
}
framework::DDim tmp_out_dims = framework::make_ddim(tmp_out_shape);
output->Resize(tmp_out_dims);
indices->Resize(tmp_out_dims);
}
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(in_dims);
for (int i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
trans_out_dims[i] = in_dims[trans[i]];
}
trans_out_dims[in_dims.size() - 1] = 1;
framework::Tensor trans_input;
trans_input.mutable_data<T>(trans_dims, ctx.GetPlace());
int ndims = trans.size();
const auto& dev_ctx = ctx.cuda_device_context();
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, *input,
&trans_input, trans);
framework::Tensor trans_ind, trans_out;
trans_ind.mutable_data<int64_t>(trans_out_dims, ctx.GetPlace());
trans_out.mutable_data<T>(trans_out_dims, ctx.GetPlace());
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
PADDLE_ENFORCE_EQ(
SortKthvalue<T>(dev_ctx, &trans_input, input_width, input_height, k,
&trans_out, &trans_ind),
true,
platform::errors::External("KthvalueOP: Error when use cub sorting"));
TransCompute<platform::CUDADeviceContext, int64_t>(
ndims, dev_ctx, trans_ind, indices, trans);
TransCompute<platform::CUDADeviceContext, T>(ndims, dev_ctx, trans_out,
output, trans);
if (!keepdim) {
output->Resize(out_dims);
indices->Resize(out_dims);
}
}
}
};
template <typename DeviceContext, typename T>
class KthvalueOpGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
PADDLE_ENFORCE_EQ(
platform::is_gpu_place(context.GetPlace()), true,
platform::errors::InvalidArgument(
"It must use CUDAPlace, you must check your device set."));
auto* x = context.Input<framework::Tensor>("X");
auto* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<framework::Tensor>("Indices");
auto* x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
int axis = context.Attr<int>("axis");
int k = static_cast<int>(context.Attr<int>("k"));
const auto& in_dims = x->dims();
auto out_dims = indices->dims();
if (axis < 0) axis += in_dims.size();
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
const T* out_grad_data = out_grad->data<T>();
const int64_t* indices_data = indices->data<int64_t>();
int pre, n, post;
GetDims(in_dims, axis, &pre, &n, &post);
auto& dev_ctx = context.cuda_device_context();
int block_size = getBlockSize(post * k);
int max_threads = dev_ctx.GetMaxPhysicalThreadCount();
const int max_blocks = std::max(((max_threads - 1) / block_size + 1), 1);
int grid_size = std::min(max_blocks, pre);
AssignGradWithAxis<T><<<grid_size, block_size, 64 * 4, dev_ctx.stream()>>>(
out_grad_data, indices_data, x_grad_data, pre, post, n, 1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
kthvalue,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::KthvalueOpCUDAKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
kthvalue_grad,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext, double>,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext, int>,
ops::KthvalueOpGradCUDAKernel<paddle::platform::CUDADeviceContext,
int64_t>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/transpose_op.h"
namespace paddle {
namespace operators {
template <typename T, typename Type>
static void getKthvalue(Type input_height, Type input_width, int input_dim,
const framework::Tensor* input, T* t_out,
Type* t_indices, const int& k) {
bool partial_sort_flag = (k * 64) < input_width;
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
std::vector<std::pair<T, Type>> col_vec;
col_vec.reserve(input_width);
if (input_dim == 1) {
auto e_input = framework::EigenVector<T>::Flatten(*input);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(j), j));
}
} else {
auto e_input = framework::EigenMatrix<T>::Reshape(*input, input_dim - 1);
for (Type j = 0; j < input_width; ++j) {
col_vec.emplace_back(std::pair<T, Type>(e_input(i, j), j));
}
}
if (partial_sort_flag) {
std::partial_sort(
col_vec.begin(), col_vec.begin() + k, col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
} else {
std::nth_element(
col_vec.begin(), col_vec.begin() + k - 1, col_vec.end(),
[](const std::pair<T, Type>& l, const std::pair<T, Type>& r) {
return (!std::isnan(static_cast<double>(l.first)) &&
std::isnan(static_cast<double>(r.first))) ||
(l.first < r.first);
});
}
t_out[i] = col_vec[k - 1].first;
t_indices[i] = col_vec[k - 1].second;
}
}
template <typename T, typename Type>
static void kthvalueAssign(const Type& input_height, const Type& input_width,
const int& input_dim, const framework::Tensor* input,
const framework::Tensor* indices, T* output_data) {
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
for (Type i = 0; i < input_height; ++i) {
if (input_dim == 1) {
auto e_input = framework::EigenVector<T>::Flatten(*input);
auto e_indices = framework::EigenVector<Type>::Flatten(*indices);
output_data[i * input_width + e_indices(0)] = e_input(0);
} else {
auto e_input = framework::EigenMatrix<T>::Reshape(*input, input_dim - 1);
auto e_indices =
framework::EigenMatrix<Type>::Reshape(*indices, input_dim - 1);
output_data[i * input_width + e_indices(i, 0)] = e_input(i, 0);
}
}
}
template <typename DeviceContext, typename T>
class KthvalueCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* input = context.Input<framework::Tensor>("X");
auto* output = context.Output<framework::Tensor>("Out");
auto* indices = context.Output<framework::Tensor>("Indices");
const auto& in_dims = input->dims();
int k = static_cast<int>(context.Attr<int>("k"));
bool keepdim = static_cast<bool>(context.Attr<bool>("keepdim"));
int axis = static_cast<int>(context.Attr<int>("axis"));
if (axis < 0) axis += in_dims.size();
T* output_data = output->mutable_data<T>(context.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(context.GetPlace());
auto out_dims = output->dims();
if (axis == in_dims.size() - 1) {
const int64_t& input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t& input_width = in_dims[in_dims.size() - 1];
getKthvalue<T, int64_t>(input_height, input_width, in_dims.size(), input,
output_data, indices_data, k);
} else {
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(in_dims.size() - 1);
for (int i = axis + 1; i < in_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
if (!keepdim) {
std::vector<int> tmp_out_shape;
for (int i = 0; i < axis; i++) {
tmp_out_shape.emplace_back(in_dims[i]);
}
tmp_out_shape.emplace_back(1);
for (int i = axis + 1; i < in_dims.size(); i++) {
tmp_out_shape.emplace_back(in_dims[i]);
}
framework::DDim tmp_out_dims = framework::make_ddim(tmp_out_shape);
output->Resize(tmp_out_dims);
indices->Resize(tmp_out_dims);
}
framework::DDim trans_dims(in_dims);
framework::DDim trans_out_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = in_dims[trans[i]];
trans_out_dims[i] = in_dims[trans[i]];
}
trans_out_dims[in_dims.size() - 1] = 1;
framework::Tensor trans_inp;
trans_inp.mutable_data<T>(trans_dims, context.GetPlace());
int ndims = trans.size();
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, *input,
&trans_inp, trans);
const int64_t input_height = framework::product(
framework::slice_ddim(trans_dims, 0, trans_dims.size() - 1));
const int64_t input_width = trans_dims[trans_dims.size() - 1];
framework::Tensor tmp_out, tmp_indices;
T* t_out = tmp_out.mutable_data<T>(trans_out_dims, context.GetPlace());
auto* t_ind =
tmp_indices.mutable_data<int64_t>(trans_out_dims, context.GetPlace());
getKthvalue<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_inp, t_out, t_ind, k);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, tmp_indices, indices, trans);
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
output, trans);
if (!keepdim) {
output->Resize(out_dims);
indices->Resize(out_dims);
}
}
}
};
template <typename DeviceContext, typename T>
class KthvalueGradCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<framework::Tensor>("X");
auto* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* indices = context.Input<framework::Tensor>("Indices");
auto* x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
int axis = static_cast<int>(context.Attr<int>("axis"));
bool keepdim = static_cast<bool>(context.Attr<bool>("keepdim"));
auto in_dims = x->dims();
auto out_dims = indices->dims();
axis = (axis < 0) ? (in_dims.size() + axis) : axis;
if (!keepdim) {
std::vector<int> tmp_out_shape;
for (int i = 0; i < axis; i++) {
tmp_out_shape.emplace_back(out_dims[i]);
}
tmp_out_shape.emplace_back(1);
for (int i = axis + 1; i < in_dims.size(); i++) {
tmp_out_shape.emplace_back(out_dims[i - 1]);
}
out_dims = framework::make_ddim(tmp_out_shape);
}
T* x_grad_data = x_grad->mutable_data<T>(context.GetPlace());
if (axis == in_dims.size() - 1) {
const int64_t input_height = framework::product(
framework::slice_ddim(in_dims, 0, in_dims.size() - 1));
const int64_t input_width = in_dims[in_dims.size() - 1];
memset(x_grad_data, 0, x_grad->numel() * sizeof(T));
if (keepdim) {
kthvalueAssign(input_height, input_width, in_dims.size(), out_grad,
indices, x_grad_data);
} else {
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
framework::Tensor out_grad_tmp, indices_tmp;
out_grad_tmp.mutable_data<T>(out_grad->dims(), dev_context.GetPlace());
indices_tmp.mutable_data<int64_t>(indices->dims(),
dev_context.GetPlace());
framework::TensorCopy(*out_grad, dev_context.GetPlace(), dev_context,
&out_grad_tmp);
framework::TensorCopy(*indices, dev_context.GetPlace(), dev_context,
&indices_tmp);
out_grad_tmp.Resize(out_dims);
indices_tmp.Resize(out_dims);
kthvalueAssign(input_height, input_width, in_dims.size(), &out_grad_tmp,
&indices_tmp, x_grad_data);
}
} else {
std::vector<int> trans;
for (int i = 0; i < axis; i++) {
trans.emplace_back(i);
}
trans.emplace_back(out_dims.size() - 1);
for (int i = axis + 1; i < out_dims.size() - 1; i++) {
trans.emplace_back(i);
}
trans.emplace_back(axis);
framework::DDim trans_dims(out_dims);
framework::DDim trans_in_dims(in_dims);
for (size_t i = 0; i < trans.size(); i++) {
trans_dims[i] = out_dims[trans[i]];
trans_in_dims[i] = in_dims[trans[i]];
}
framework::Tensor trans_dO, trans_ind;
trans_dO.mutable_data<T>(trans_dims, context.GetPlace());
trans_ind.mutable_data<int64_t>(trans_dims, context.GetPlace());
int ndims = trans.size();
auto& dev_context =
context.template device_context<platform::CPUDeviceContext>();
if (keepdim) {
TransCompute<platform::CPUDeviceContext, T>(
ndims, dev_context, *out_grad, &trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, *indices, &trans_ind, trans);
} else {
framework::Tensor out_grad_tmp, indices_tmp;
out_grad_tmp.mutable_data<T>(out_grad->dims(), dev_context.GetPlace());
indices_tmp.mutable_data<int64_t>(indices->dims(),
dev_context.GetPlace());
framework::TensorCopy(*out_grad, dev_context.GetPlace(), dev_context,
&out_grad_tmp);
framework::TensorCopy(*indices, dev_context.GetPlace(), dev_context,
&indices_tmp);
out_grad_tmp.Resize(out_dims);
indices_tmp.Resize(out_dims);
TransCompute<platform::CPUDeviceContext, T>(
ndims, dev_context, out_grad_tmp, &trans_dO, trans);
TransCompute<platform::CPUDeviceContext, int64_t>(
ndims, dev_context, indices_tmp, &trans_ind, trans);
}
const int64_t input_height = framework::product(
framework::slice_ddim(trans_in_dims, 0, trans_in_dims.size() - 1));
const int64_t input_width = trans_in_dims[trans_in_dims.size() - 1];
framework::Tensor tmp_out;
T* t_out = tmp_out.mutable_data<T>(trans_in_dims, context.GetPlace());
memset(t_out, 0, x_grad->numel() * sizeof(T));
kthvalueAssign<T, int64_t>(input_height, input_width, in_dims.size(),
&trans_dO, &trans_ind, t_out);
TransCompute<platform::CPUDeviceContext, T>(ndims, dev_context, tmp_out,
x_grad, trans);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -275,6 +275,7 @@ from .tensor.search import where # noqa: F401
from .tensor.search import index_select # noqa: F401
from .tensor.search import nonzero # noqa: F401
from .tensor.search import sort # noqa: F401
from .tensor.search import kthvalue # noqa: F401
from .tensor.search import mode # noqa: F401
from .tensor.to_string import set_printoptions # noqa: F401
......@@ -615,6 +616,7 @@ __all__ = [ # noqa
'moveaxis',
'repeat_interleave',
'clone',
'kthvalue',
'renorm',
'take_along_axis',
'put_along_axis',
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
import paddle
import paddle.fluid as fluid
def cal_kthvalue(x, k, axis, keepdim=False):
if axis < 0:
axis = len(x.shape) + axis
indices = np.argsort(x, axis=axis)
value = np.sort(x, axis=axis)
indices = indices.take(indices=k - 1, axis=axis)
value = value.take(indices=k - 1, axis=axis)
if keepdim:
indices = np.expand_dims(indices, axis)
value = np.expand_dims(value, axis)
return value, indices
class TestKthvalueOp(OpTest):
def init_args(self):
self.k = 5
self.axis = -1
def setUp(self):
self.op_type = "kthvalue"
self.dtype = np.float64
self.input_data = np.random.random((2, 1, 2, 4, 10))
self.init_args()
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis}
output, indices = cal_kthvalue(
self.input_data, k=self.k, axis=self.axis)
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def test_check_grad(self):
paddle.enable_static()
self.check_grad(set(['X']), 'Out')
class TestKthvalueOpWithKeepdim(OpTest):
def init_args(self):
self.k = 2
self.axis = 1
def setUp(self):
self.init_args()
self.op_type = "kthvalue"
self.dtype = np.float64
self.input_data = np.random.random((1, 3, 2, 4, 10))
self.inputs = {'X': self.input_data}
self.attrs = {'k': self.k, 'axis': self.axis, 'keepdim': True}
output, indices = cal_kthvalue(
self.input_data, k=self.k, axis=self.axis, keepdim=True)
self.outputs = {'Out': output, 'Indices': indices}
def test_check_output(self):
paddle.enable_static()
self.check_output()
def test_check_grad(self):
paddle.enable_static()
self.check_grad(set(['X']), 'Out')
class TestKthvalueOpKernels(unittest.TestCase):
def setUp(self):
self.axises = [2, -1]
def test_kthvalue_op(self):
paddle.disable_static()
def test_cpu_kernel():
shape = (2, 128, 10)
k = 2
paddle.set_device('cpu')
inputs = np.random.random(shape)
tensor = paddle.to_tensor(inputs)
for axis in self.axises:
value_expect, indice_expect = cal_kthvalue(inputs, k, axis)
v, inds = paddle.kthvalue(tensor, k, axis)
self.assertTrue(np.allclose(v.numpy(), value_expect))
self.assertTrue(np.allclose(inds.numpy(), indice_expect))
def test_gpu_kernel():
shape = (2, 30, 250)
k = 244
paddle.set_device('gpu')
inputs = np.random.random(shape)
tensor = paddle.to_tensor(inputs)
for axis in self.axises:
value_expect, indice_expect = cal_kthvalue(inputs, k, axis)
v, inds = paddle.kthvalue(tensor, k, axis)
self.assertTrue(np.allclose(v.numpy(), value_expect))
self.assertTrue(np.allclose(inds.numpy(), indice_expect))
test_cpu_kernel()
if fluid.core.is_compiled_with_cuda():
test_gpu_kernel()
class TestKthvalueOpWithNaN(unittest.TestCase):
def setUp(self):
paddle.disable_static()
self.x = paddle.uniform([2, 200, 10], dtype='float32')
def test_errors(self):
def test_nan_in_cpu_kernel():
paddle.set_device('cpu')
nan_position = 100
self.x[0, nan_position, 2] = float('nan')
v, inds = self.x.kthvalue(k=200, axis=1)
self.assertTrue(np.isnan(v[0, 2].numpy()[0]))
self.assertEqual(inds[0, 2].numpy()[0], nan_position)
def test_nan_in_gpu_kernel():
paddle.set_device('gpu')
nan_position = 100
self.x[0, nan_position, 2] = float('nan')
v, inds = self.x.kthvalue(k=200, axis=1)
self.assertTrue(np.isnan(v[0, 2].numpy()[0]))
self.assertEqual(inds[0, 2].numpy()[0], nan_position)
test_nan_in_cpu_kernel()
if fluid.core.is_compiled_with_cuda():
test_nan_in_gpu_kernel()
class TestKthvalueOpErrors(unittest.TestCase):
def setUp(self):
self.x = paddle.uniform([2, 10, 20, 25], dtype='float32')
def test_errors(self):
paddle.disable_static()
def test_k_lowrange_error():
self.x.kthvalue(k=0, axis=2)
self.assertRaises(ValueError, test_k_lowrange_error)
def test_k_uprange_error():
self.x.kthvalue(k=500, axis=2)
self.assertRaises(ValueError, test_k_uprange_error)
def test_dim_range_error():
self.x.kthvalue(k=10, axis=5)
self.assertRaises(ValueError, test_dim_range_error)
class TestModeOpInStatic(unittest.TestCase):
def setUp(self):
np.random.seed(666)
self.input_data = np.random.random((2, 20, 1, 2, 80)).astype(np.float64)
self.k = 10
def test_run_static(self):
paddle.enable_static()
with paddle.static.program_guard(paddle.static.Program(),
paddle.static.Program()):
input_tensor = paddle.static.data(
name="x", shape=[2, 20, 1, 2, 80], dtype="float64")
result = paddle.kthvalue(input_tensor, self.k, axis=1)
expect_value = cal_kthvalue(self.input_data, self.k, axis=1)[0]
exe = paddle.static.Executor(paddle.CPUPlace())
paddle_result = exe.run(feed={"x": self.input_data},
fetch_list=[result])[0]
self.assertTrue(np.allclose(paddle_result, expect_value))
if __name__ == '__main__':
unittest.main()
......@@ -251,6 +251,7 @@ from .search import nonzero # noqa: F401
from .search import sort # noqa: F401
from .search import index_sample # noqa: F401
from .search import masked_select # noqa: F401
from .search import kthvalue # noqa: F401
from .search import mode # noqa: F401
from .stat import mean # noqa: F401
......@@ -366,6 +367,7 @@ tensor_method_func = [ #noqa
'clip_',
'trace',
'kron',
'kthvalue',
'isfinite',
'isinf',
'isnan',
......
......@@ -891,3 +891,65 @@ def searchsorted(sorted_sequence,
"right": right})
return out
def kthvalue(x, k, axis=None, keepdim=False, name=None):
"""
This OP is used to find values and indices of the k-th smallest at the axis.
Args:
x(Tensor): A N-D Tensor with type float32, float64, int32, int64.
k(int): The k for the k-th smallest number to look for along the axis.
axis(int, optional): Axis to compute indices along. The effective range
is [-R, R), where R is x.ndim. when axis < 0, it works the same way
as axis + R. The default is None. And if the axis is None, it will computed as -1 by default.
keepdim(bool, optional): Whether to keep the given axis in output. If it is True, the dimensions will be same as input x and with size one in the axis. Otherwise the output dimentions is one fewer than x since the axis is squeezed. Default is False.
name (str, optional): Name for the operation (optional, default is None). For more information, please refer to :ref:`api_guide_Name`.
Returns:
tuple(Tensor), return the values and indices. The value data type is the same as the input `x`. The indices data type is int64.
Examples:
.. code-block:: python
import paddle
x = paddle.randn((2,3,2))
# Tensor(shape=[2, 3, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[[ 0.22954939, -0.01296274],
# [ 1.17135799, -0.34493217],
# [-0.19550551, -0.17573971]],
#
# [[ 0.15104349, -0.93965352],
# [ 0.14745511, 0.98209465],
# [ 0.10732264, -0.55859774]]])
y = paddle.kthvalue(x, 2, 1)
# (Tensor(shape=[2, 2], dtype=float32, place=CUDAPlace(0), stop_gradient=True,
# [[ 0.22954939, -0.17573971],
# [ 0.14745511, -0.55859774]]), Tensor(shape=[2, 2], dtype=int64, place=CUDAPlace(0), stop_gradient=True,
# [[0, 2],
# [1, 2]]))
"""
if in_dygraph_mode():
if axis is not None:
return _C_ops.kthvalue(x, 'k', k, "axis", axis, "keepdim", keepdim)
else:
return _C_ops.kthvalue(x, 'k', k, "keepdim", keepdim)
helper = LayerHelper("kthvalue", **locals())
inputs = {"X": [x]}
attrs = {'k': k}
if axis is not None:
attrs['axis'] = axis
values = helper.create_variable_for_type_inference(dtype=x.dtype)
indices = helper.create_variable_for_type_inference(dtype="int64")
helper.append_op(
type="kthvalue",
inputs=inputs,
outputs={"Out": [values],
"Indices": [indices]},
attrs=attrs)
indices.stop_gradient = True
return values, indices
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册