提交 3fbb692d 编写于 作者: 武毅 提交者: GitHub

Add topk op (#3760)

* init add

* add topk op

* someupdate

* fix style check

* add test py file

* update top k cuda kernel

* follow comments

* remove debug print

* fix casting error

* fix casting error

* fix casting error

* fix rename bug...

* fix travis
上级 2f40da09
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/top_k_op.h"
namespace paddle {
namespace operators {
class TopkOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"Input of TopkOP must be initialized.");
auto *input = ctx.Input<framework::Tensor>("X");
const int k = static_cast<int>(ctx.Attr<int>("k"));
PADDLE_ENFORCE_GE(k, 1, "k must >= 1");
PADDLE_ENFORCE_GE(input->dims().size(), 1, "input must have >= 1d shape");
PADDLE_ENFORCE_GE(input->dims()[input->dims().size() - 1], k,
"input must have >= k columns");
framework::DDim dims = input->dims();
dims[dims.size() - 1] = k;
ctx.Output<Tensor>("Out")->Resize(dims);
ctx.Output<Tensor>("Indices")->Resize(dims);
}
};
class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
public:
TopkOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input of Topk op");
AddOutput("Out", "The output tensor of Topk op");
AddOutput("Indices", "The indices of Topk elements of input");
AddComment(
R"DOC(If the input is a vector (1d tensor), finds the k largest entries in the vector and outputs their values and indices as vectors. Thus values[j] is the j-th largest entry in input, and its index is indices[j].
For matrices, computes the top k entries in each row. )DOC");
AddAttr<int>("k",
"Number of top elements to look for along the last "
"dimension (along each row for matrices).")
.SetDefault(1);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_WITHOUT_GRADIENT(top_k, ops::TopkOp, ops::TopkOpMaker);
REGISTER_OP_CPU_KERNEL(top_k,
ops::TopkKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/platform/assert.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
struct Pair {
__device__ __forceinline__ Pair() {}
__device__ __forceinline__ Pair(T value, int id) : v(value), id(id) {}
__device__ __forceinline__ void set(T value, int id) {
v = value;
id = id;
}
__device__ __forceinline__ void operator=(const Pair<T>& in) {
v = in.v;
id = in.id;
}
__device__ __forceinline__ bool operator<(const T value) const {
return (v < value);
}
__device__ __forceinline__ bool operator<(const Pair<T>& in) const {
return (v < in.v) || ((v == in.v) && (id > in.id));
}
__device__ __forceinline__ bool operator>(const Pair<T>& in) const {
return (v > in.v) || ((v == in.v) && (id < in.id));
}
T v;
int id;
};
template <typename T>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p,
int beam_size) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
topk[0] = p;
}
template <typename T, int beam_size>
__device__ __forceinline__ void AddTo(Pair<T> topk[], const Pair<T>& p) {
for (int k = beam_size - 2; k >= 0; k--) {
if (topk[k] < p) {
topk[k + 1] = topk[k];
} else {
topk[k + 1] = p;
return;
}
}
topk[0] = p;
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
AddTo<T>(topk, tmp, beam_size);
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* src, int idx,
int dim, const Pair<T>& max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < src[idx]) {
Pair<T> tmp(src[idx], idx);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size);
}
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
int idx, int dim, int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < val[idx]) {
Pair<T> tmp(val[idx], col[idx]);
AddTo<T>(topk, tmp, beam_size);
}
idx += BlockSize;
}
}
template <typename T, int BlockSize>
__device__ __forceinline__ void GetTopK(Pair<T> topk[], const T* val, int* col,
int idx, int dim, const Pair<T>& max,
int beam_size) {
while (idx < dim) {
if (topk[beam_size - 1] < val[idx]) {
Pair<T> tmp(val[idx], col[idx]);
if (tmp < max) {
AddTo<T>(topk, tmp, beam_size);
}
}
idx += BlockSize;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
int beam_size, const T* src,
bool& firstStep, bool& is_empty,
Pair<T>& max, int dim,
const int tid) {
if (beam > 0) {
int length = beam < beam_size ? beam : beam_size;
if (firstStep) {
firstStep = false;
GetTopK<T, BlockSize>(topk, src, tid, dim, length);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - beam) {
topk[k] = topk[k + beam];
} else {
topk[k].set(-INFINITY, -1);
}
}
if (!is_empty) {
GetTopK<T, BlockSize>(topk + MaxLength - beam, src, tid, dim, max,
length);
}
}
max = topk[MaxLength - 1];
if (max.v == -1) is_empty = true;
beam = 0;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void ThreadGetTopK(Pair<T> topk[], int& beam,
int beam_size, const T* val,
int* col, bool& firstStep,
bool& is_empty, Pair<T>& max,
int dim, const int tid) {
if (beam > 0) {
int length = beam < beam_size ? beam : beam_size;
if (firstStep) {
firstStep = false;
GetTopK<T, BlockSize>(topk, val, col, tid, dim, length);
} else {
for (int k = 0; k < MaxLength; k++) {
if (k < MaxLength - beam) {
topk[k] = topk[k + beam];
} else {
topk[k].set(-INFINITY, -1);
}
}
if (!is_empty) {
GetTopK<T, BlockSize>(topk + MaxLength - beam, val, col, tid, dim, max,
length);
}
}
max = topk[MaxLength - 1];
if (max.v == -1) is_empty = true;
beam = 0;
}
}
template <typename T, int MaxLength, int BlockSize>
__device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
Pair<T> topk[], T** topVal,
int** topIds, int& beam, int& k,
const int tid, const int warp) {
while (true) {
__syncthreads();
if (tid < BlockSize / 2) {
if (sh_topk[tid] < sh_topk[tid + BlockSize / 2]) {
maxid[tid] = tid + BlockSize / 2;
} else {
maxid[tid] = tid;
}
}
__syncthreads();
for (int stride = BlockSize / 4; stride > 0; stride = stride / 2) {
if (tid < stride) {
if (sh_topk[maxid[tid]] < sh_topk[maxid[tid + stride]]) {
maxid[tid] = maxid[tid + stride];
}
}
__syncthreads();
}
__syncthreads();
if (tid == 0) {
**topVal = sh_topk[maxid[0]].v;
**topIds = sh_topk[maxid[0]].id;
(*topVal)++;
(*topIds)++;
}
if (tid == maxid[0]) beam++;
if (--k == 0) break;
__syncthreads();
if (tid == maxid[0]) {
if (beam < MaxLength) {
sh_topk[tid] = topk[beam];
}
}
if (maxid[0] / 32 == warp) {
if (__shfl(beam, (maxid[0]) % 32, 32) == MaxLength) break;
}
}
}
/**
* Each block compute one sample.
* In a block:
* 1. every thread get top MaxLength value;
* 2. merge to sh_topk, block reduce and get max value;
* 3. go to the second setp, until one thread's topk value is null;
* 4. go to the first setp, until get the topk value.
*/
template <typename T, int MaxLength, int BlockSize>
__global__ void KeMatrixTopK(T* output, int output_stride, int* indices,
const T* src, int lds, int dim, int k) {
__shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
const int tid = threadIdx.x;
const int warp = threadIdx.x / 32;
output += blockIdx.x * output_stride;
indices += blockIdx.x * k;
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
for (int k = 0; k < MaxLength; k++) {
topk[k].set(-INFINITY, -1);
}
while (k) {
ThreadGetTopK<T, MaxLength, BlockSize>(topk, beam, k,
src + blockIdx.x * lds, firststep,
is_empty, max, dim, tid);
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
&indices, beam, k, tid, warp);
}
}
template <typename T>
class TopkOpCUDAKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE(platform::is_gpu_place(ctx.GetPlace()),
"It must use GPUPlace.");
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
size_t k = static_cast<int>(ctx.Attr<int>("k"));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T?
int* indices_data = indices->mutable_data<int>(ctx.GetPlace());
size_t input_height = input->dims()[0];
size_t input_width = input->dims()[1];
if (k > input_width) k = input_width;
// NOTE: pass lds and dim same to input width.
// NOTE: old matrix implementation of stride is different to eigen.
// TODO(typhoonzero): launch kernel on specified stream.
// TODO(typhoonzero): refine this kernel.
dim3 threads(256, 1);
dim3 grid(input_height, 1);
KeMatrixTopK<T, 5, 256><<<grid, threads>>>(
output_data, output->dims()[1], indices_data, input_data, input_width,
input_width, int(k));
}
};
} // namespace operators
} // namespace paddle
REGISTER_OP_GPU_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm>
#include <iostream>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename Place, typename T>
class TopkKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor
// FIXME: only deal with matrix(2d tensor).
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
// k is determined by Attr
const size_t k = static_cast<int>(ctx.Attr<int>("k"));
T* output_data = output->mutable_data<T>(ctx.GetPlace());
T* indices_data = indices->mutable_data<T>(ctx.GetPlace());
auto eg_input = EigenMatrix<T>::From(*input);
// reshape input to a flattern matrix(like flat_inner_dims)
framework::DDim inputdims = input->dims();
const size_t row = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t col = inputdims[inputdims.size() - 1];
Eigen::DSizes<int, 2> flat2dims(row, col);
// NOTE: eigen shape doesn't affect paddle tensor.
eg_input.reshape(flat2dims);
for (size_t i = 0; i < row; i++) {
std::vector<std::pair<T, size_t>> vec;
for (size_t j = 0; j < col; j++) {
vec.push_back(std::pair<T, size_t>(eg_input(i, j), j));
}
std::partial_sort(
vec.begin(), vec.begin() + k, vec.end(),
[](const std::pair<T, size_t>& l, const std::pair<T, size_t>& r) {
return l.first > r.first;
});
for (size_t j = 0; j < k; j++) {
output_data[i * k + j] = vec[j].first;
indices_data[i * k + j] = vec[j].second;
}
}
}
};
} // namespace operators
} // namespace paddle
......@@ -49,6 +49,7 @@ USE_OP(minus);
USE_OP(cos_sim);
USE_CPU_ONLY_OP(gather);
USE_CPU_ONLY_OP(scatter);
USE_OP(top_k);
USE_OP(squared_l2_distance);
namespace paddle {
......
......@@ -37,7 +37,7 @@ Configuring cmake in /paddle/build ...
-DWITH_PYTHON=${WITH_PYTHON:-ON}
-DWITH_SWIG_PY=${WITH_SWIG_PY:-ON}
-DCUDNN_ROOT=/usr/
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-OFF}
-DWITH_STYLE_CHECK=${WITH_STYLE_CHECK:-ON}
-DWITH_TESTING=${WITH_TESTING:-ON}
-DCMAKE_EXPORT_COMPILE_COMMANDS=ON
========================================
......
......@@ -17,6 +17,7 @@ py_test(test_cross_entropy_op SRCS test_cross_entropy_op.py)
py_test(test_gather_op SRCS test_gather_op.py)
py_test(test_scatter_op SRCS test_scatter_op.py)
py_test(test_fill_zeros_like_op SRCS test_fill_zeros_like_op.py)
py_test(test_top_k_op SRCS test_top_k_op.py)
py_test(gradient_checker SRCS gradient_checker.py)
......
import unittest
import numpy as np
from gradient_checker import GradientChecker, create_op
from op_test_util import OpTestMeta
class TestTopkOp(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "top_k"
k = 1
input = np.random.random((32, 84)).astype("float32")
output = np.ndarray((32, k))
indices = np.ndarray((32, k))
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in xrange(32):
row = input[rowid]
output[rowid] = np.sort(row)[-k:]
indices[rowid] = row.argsort()[-k:]
self.outputs = {'Out': output, 'Indices': indices}
class TestTopkOp3d(unittest.TestCase):
__metaclass__ = OpTestMeta
def setUp(self):
self.type = "top_k"
k = 1
input = np.random.random((32, 2, 84)).astype("float32")
input_flat_2d = input.reshape(64, 84)
output = np.ndarray((64, k))
indices = np.ndarray((64, k)).astype("int")
# FIXME: should use 'X': input for a 3d input
self.inputs = {'X': input_flat_2d}
self.attrs = {'k': k}
for rowid in xrange(64):
row = input_flat_2d[rowid]
output[rowid] = np.sort(row)[-k:]
indices[rowid] = row.argsort()[-k:]
self.outputs = {'Out': output, 'Indices': indices}
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册