From c8adc2c6fec3414ee6be49205a51b6d9e32756d6 Mon Sep 17 00:00:00 2001 From: dzhwinter Date: Mon, 29 Oct 2018 10:40:02 +0800 Subject: [PATCH] cudnn version. staged. --- paddle/fluid/operators/top_k_op.cc | 2 +- paddle/fluid/operators/top_k_op.cu | 99 ++++++++++++++++++++---------- paddle/fluid/operators/top_k_op.h | 5 +- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 4a8ac441cf..c17d1afc30 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", "(Tensor) The input of Topk op"); - AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X"); + AddOutput("Out", "(Tensor) The output tensor of Topk op"); AddOutput("Indices", "(Tensor) The indices of Topk elements of input"); AddComment(R"DOC( Top K operator diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index 9da8551eb2..0cad224ca8 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair* sh_topk, int* maxid, * 3. go to the second setp, until one thread's topk value is null; * 4. go to the first setp, until get the topk value. */ + template __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices, - const T* src, int lds, int dim, int k) { + const T* src, int lds, int dim, int k, + int grid_dim, int num) { __shared__ Pair sh_topk[BlockSize]; - __shared__ int maxid[BlockSize / 2]; const int tid = threadIdx.x; const int warp = threadIdx.x / 32; - output += blockIdx.x * output_stride; - indices += blockIdx.x * k; - Pair topk[MaxLength]; - int beam = MaxLength; - Pair max; - bool is_empty = false; - bool firststep = true; + const int bid = blockIdx.x; + for (int i = bid; i < num; i += grid_dim) { + int top_num = k; + __shared__ int maxid[BlockSize / 2]; + T* out = output + i * output_stride; + int64_t* inds = indices + i * k; + Pair topk[MaxLength]; + int beam = MaxLength; + Pair max; + bool is_empty = false; + bool firststep = true; + + for (int j = 0; j < MaxLength; j++) { + topk[j].set(-INFINITY, -1); + } + while (top_num) { + ThreadGetTopK( + topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid); - for (int k = 0; k < MaxLength; k++) { - topk[k].set(-INFINITY, -1); + sh_topk[tid] = topk[0]; + BlockReduce(sh_topk, maxid, topk, &out, &inds, + &beam, &top_num, tid, warp); + } } - while (k) { - ThreadGetTopK(topk, &beam, k, - src + blockIdx.x * lds, &firststep, - &is_empty, &max, dim, tid); - - sh_topk[tid] = topk[0]; - BlockReduce(sh_topk, maxid, topk, &output, - &indices, &beam, &k, tid, warp); +} + +inline static int GetDesiredBlockDim(int dim) { + if (dim > 128) { + return 256; + } else if (dim > 64) { + return 128; + } else if (dim > 32) { + return 64; + } else { + return 32; } } +#define FIXED_BLOCK_DIM_BASE(dim, ...) \ + case (dim): { \ + constexpr auto kBlockDim = (dim); \ + __VA_ARGS__; \ + } break + +#define FIXED_BLOCK_DIM(...) \ + FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__); \ + FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__) + template class TopkOpCUDAKernel : public framework::OpKernel { public: @@ -298,30 +327,38 @@ class TopkOpCUDAKernel : public framework::OpKernel { size_t k = static_cast(ctx.Attr("k")); const T* input_data = input->data(); - T* output_data = output->mutable_data(ctx.GetPlace()); // FIXME(typhoonzero): data is always converted to type T? int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - size_t input_height = input->dims()[0]; - size_t input_width = input->dims()[1]; + framework::DDim inputdims = input->dims(); + const size_t input_height = framework::product( + framework::slice_ddim(inputdims, 0, inputdims.size() - 1)); + const size_t input_width = inputdims[inputdims.size() - 1]; + if (k > input_width) k = input_width; // NOTE: pass lds and dim same to input width. // NOTE: old matrix implementation of stride is different to eigen. // TODO(typhoonzero): refine this kernel. - dim3 threads(256, 1); - dim3 grid(input_height, 1); - - KeMatrixTopK<<< - grid, threads, 0, reinterpret_cast( - ctx.device_context()) - .stream()>>>( - output_data, output->dims()[1], indices_data, input_data, input_width, - input_width, static_cast(k)); + const int kMaxHeight = 2048; + int gridx = input_height < kMaxHeight ? input_height : kMaxHeight; + auto& dev_ctx = ctx.cuda_device_context(); + switch (GetDesiredBlockDim(input_width)) { + FIXED_BLOCK_DIM( + KeMatrixTopK<<>>( + output_data, k, indices_data, input_data, input_width, + input_width, static_cast(k), gridx, input_height)); + default: + PADDLE_THROW("Error"); + } } }; +#undef FIXED_BLOCK_DIM_BASE +#undef FIXED_BLOCK_DIM + } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/top_k_op.h b/paddle/fluid/operators/top_k_op.h index 054dd48199..76ece57b39 100644 --- a/paddle/fluid/operators/top_k_op.h +++ b/paddle/fluid/operators/top_k_op.h @@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const override { // Get the top k elements of each row of input tensor - // FIXME: only deal with matrix(2d tensor). auto* input = ctx.Input("X"); auto* output = ctx.Output("Out"); auto* indices = ctx.Output("Indices"); @@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel { T* output_data = output->mutable_data(ctx.GetPlace()); int64_t* indices_data = indices->mutable_data(ctx.GetPlace()); - auto eg_input = EigenMatrix::From(*input); - // reshape input to a flattern matrix(like flat_inner_dims) framework::DDim inputdims = input->dims(); const size_t row = framework::product( @@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel { const size_t col = inputdims[inputdims.size() - 1]; Eigen::DSizes flat2dims(row, col); // NOTE: eigen shape doesn't affect paddle tensor. - eg_input.reshape(flat2dims); + auto eg_input = EigenMatrix::Reshape(*input, inputdims.size() - 1); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for -- GitLab