提交 c7379a73 编写于 作者: Q qingqing01 提交者: whs

Fix top_k op (#14034)

1. Fix CUDA kernel when height is large than 2048.
2. Support input with more than 2D.
3. Fix unit test when k is large than 1.
4. Enhence unit testing.

test=develop
上级 d3e52551
......@@ -262,31 +262,31 @@ __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
const T* src, int lds, int dim, int k,
int grid_dim, int num) {
__shared__ Pair<T> sh_topk[BlockSize];
__shared__ int maxid[BlockSize / 2];
const int tid = threadIdx.x;
const int warp = threadIdx.x / 32;
const int bid = blockIdx.x;
for (int i = bid; i < num; i += grid_dim) {
output += i * output_stride;
indices += i * k;
int top_num = k;
__shared__ int maxid[BlockSize / 2];
T* out = output + i * output_stride;
int64_t* inds = indices + i * k;
Pair<T> topk[MaxLength];
int beam = MaxLength;
Pair<T> max;
bool is_empty = false;
bool firststep = true;
for (int k = 0; k < MaxLength; k++) {
topk[k].set(-INFINITY, -1);
for (int j = 0; j < MaxLength; j++) {
topk[j].set(-INFINITY, -1);
}
while (k) {
while (top_num) {
ThreadGetTopK<T, MaxLength, BlockSize>(
topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
sh_topk[tid] = topk[0];
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
&indices, &beam, &k, tid, warp);
BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
&beam, &top_num, tid, warp);
}
}
}
......@@ -327,13 +327,15 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
size_t k = static_cast<int>(ctx.Attr<int>("k"));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
// FIXME(typhoonzero): data is always converted to type T?
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
size_t input_height = input->dims()[0];
size_t input_width = input->dims()[1];
framework::DDim inputdims = input->dims();
const size_t input_height = framework::product(
framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
const size_t input_width = inputdims[inputdims.size() - 1];
if (k > input_width) k = input_width;
// NOTE: pass lds and dim same to input width.
......@@ -342,14 +344,12 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
const int kMaxHeight = 2048;
int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
auto& dev_ctx = ctx.cuda_device_context();
switch (GetDesiredBlockDim(input_width)) {
FIXED_BLOCK_DIM(
KeMatrixTopK<T, 5,
kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
output_data, output->dims()[1], indices_data, input_data,
input_width, input_width, static_cast<int>(k), gridx,
input_height));
output_data, k, indices_data, input_data, input_width,
input_width, static_cast<int>(k), gridx, input_height));
default:
PADDLE_THROW("Error");
}
......
......@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
// Get the top k elements of each row of input tensor
// FIXME: only deal with matrix(2d tensor).
auto* input = ctx.Input<Tensor>("X");
auto* output = ctx.Output<Tensor>("Out");
auto* indices = ctx.Output<Tensor>("Indices");
......@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
auto eg_input = EigenMatrix<T>::From(*input);
// reshape input to a flattern matrix(like flat_inner_dims)
framework::DDim inputdims = input->dims();
const size_t row = framework::product(
......@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
const size_t col = inputdims[inputdims.size() - 1];
Eigen::DSizes<int, 2> flat2dims(row, col);
// NOTE: eigen shape doesn't affect paddle tensor.
eg_input.reshape(flat2dims);
auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
......
......@@ -21,22 +21,27 @@ from op_test import OpTest
class TestTopkOp(OpTest):
def setUp(self):
self.set_args()
self.op_type = "top_k"
k = 1
input = np.random.random((32, 84)).astype("float32")
output = np.ndarray((32, k))
indices = np.ndarray((32, k)).astype("int64")
k = self.top_k
input = np.random.random((self.row, k)).astype("float32")
output = np.ndarray((self.row, k))
indices = np.ndarray((self.row, k)).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(32):
for rowid in range(self.row):
row = input[rowid]
output[rowid] = np.sort(row)[-k:]
indices[rowid] = row.argsort()[-k:]
output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {'Out': output, 'Indices': indices}
def set_args(self):
self.row = 32
self.top_k = 1
def test_check_output(self):
self.check_output()
......@@ -50,14 +55,39 @@ class TestTopkOp3d(OpTest):
output = np.ndarray((64, k))
indices = np.ndarray((64, k)).astype("int64")
# FIXME: should use 'X': input for a 3d input
self.inputs = {'X': input_flat_2d}
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(64):
row = input_flat_2d[rowid]
output[rowid] = np.sort(row)[-k:]
indices[rowid] = row.argsort()[-k:]
output[rowid] = np.sort(row)[::-1][:k]
indices[rowid] = row.argsort()[::-1][:k]
self.outputs = {
'Out': output.reshape((32, 2, k)),
'Indices': indices.reshape((32, 2, k))
}
def test_check_output(self):
self.check_output()
class TestTopkOp2(OpTest):
def setUp(self):
self.op_type = "top_k"
k = 1
m = 2056
input = np.random.random((m, 84)).astype("float32")
output = np.ndarray((m, k))
indices = np.ndarray((m, k)).astype("int64")
self.inputs = {'X': input}
self.attrs = {'k': k}
for rowid in range(m):
row = input[rowid]
output[rowid] = -np.sort(-row)[:k]
indices[rowid] = (-row).argsort()[:k]
self.outputs = {'Out': output, 'Indices': indices}
......@@ -65,5 +95,17 @@ class TestTopkOp3d(OpTest):
self.check_output()
class TestTopkOp3(TestTopkOp):
def set_args(self):
self.row = 2056
self.top_k = 3
class TestTopkOp4(TestTopkOp):
def set_args(self):
self.row = 40000
self.top_k = 1
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册