Commit c8adc2c6 authored by dzhwinter

cudnn version. staged.

Parent 7141debe
...
@@ -50,7 +50,7 @@ class TopkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "(Tensor) The input of Topk op");
-    AddOutput("Out", "(Tensor) The output tensor of Topk op").Reuse("X");
+    AddOutput("Out", "(Tensor) The output tensor of Topk op");
     AddOutput("Indices", "(Tensor) The indices of Topk elements of input");
     AddComment(R"DOC(
 Top K operator
...
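Note: `.Reuse("X")` flagged `Out` as able to reuse the input variable's memory (an in-place hint to the memory optimizer). The commit message does not explain the removal, but since `Out` has last dimension `k` while `X` keeps the full row width, sharing the buffer was dubious to begin with; after this change `Out` always gets its own allocation.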
...
@@ -256,36 +256,65 @@ __device__ __forceinline__ void BlockReduce(Pair<T>* sh_topk, int* maxid,
  * 3. go to the second step, until one thread's topk value is null;
  * 4. go to the first step, until get the topk value.
  */
 template <typename T, int MaxLength, int BlockSize>
 __global__ void KeMatrixTopK(T* output, int output_stride, int64_t* indices,
-                             const T* src, int lds, int dim, int k) {
+                             const T* src, int lds, int dim, int k,
+                             int grid_dim, int num) {
   __shared__ Pair<T> sh_topk[BlockSize];
-  __shared__ int maxid[BlockSize / 2];
   const int tid = threadIdx.x;
   const int warp = threadIdx.x / 32;
-  output += blockIdx.x * output_stride;
-  indices += blockIdx.x * k;
-  Pair<T> topk[MaxLength];
-  int beam = MaxLength;
-  Pair<T> max;
-  bool is_empty = false;
-  bool firststep = true;
-  for (int k = 0; k < MaxLength; k++) {
-    topk[k].set(-INFINITY, -1);
-  }
-  while (k) {
-    ThreadGetTopK<T, MaxLength, BlockSize>(topk, &beam, k,
-                                           src + blockIdx.x * lds, &firststep,
-                                           &is_empty, &max, dim, tid);
-    sh_topk[tid] = topk[0];
-    BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &output,
-                                         &indices, &beam, &k, tid, warp);
+  const int bid = blockIdx.x;
+  for (int i = bid; i < num; i += grid_dim) {
+    int top_num = k;
+    __shared__ int maxid[BlockSize / 2];
+    T* out = output + i * output_stride;
+    int64_t* inds = indices + i * k;
+    Pair<T> topk[MaxLength];
+    int beam = MaxLength;
+    Pair<T> max;
+    bool is_empty = false;
+    bool firststep = true;
+    for (int j = 0; j < MaxLength; j++) {
+      topk[j].set(-INFINITY, -1);
+    }
+    while (top_num) {
+      ThreadGetTopK<T, MaxLength, BlockSize>(
+          topk, &beam, k, src + i * lds, &firststep, &is_empty, &max, dim, tid);
+      sh_topk[tid] = topk[0];
+      BlockReduce<T, MaxLength, BlockSize>(sh_topk, maxid, topk, &out, &inds,
+                                           &beam, &top_num, tid, warp);
+    }
   }
 }
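Note: the structural change in the kernel is a grid-stride loop. Previously each block owned exactly one row (`output += blockIdx.x * output_stride`), so the grid needed one block per row; now a fixed grid of `grid_dim` blocks sweeps all `num` rows. Renaming the init-loop index (`k` to `j`) also stops it shadowing the parameter `k`, and the per-row countdown moves into `top_num` so `k` survives across iterations. A minimal sketch of the pattern, with hypothetical names (`row_kernel` is not from the Paddle source):

__global__ void row_kernel(const float* src, float* dst, int row_len,
                           int num_rows) {
  // Each block starts at its own index and advances by the grid size,
  // so rows beyond gridDim.x are still covered by some block.
  for (int row = blockIdx.x; row < num_rows; row += gridDim.x) {
    const float* in = src + row * row_len;  // this block's current row
    // Stand-in for the per-row work (here, the block-wide top-k reduction):
    if (threadIdx.x == 0) dst[row] = in[0];
  }
}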
+
+inline static int GetDesiredBlockDim(int dim) {
+  if (dim > 128) {
+    return 256;
+  } else if (dim > 64) {
+    return 128;
+  } else if (dim > 32) {
+    return 64;
+  } else {
+    return 32;
+  }
+}
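Note: `GetDesiredBlockDim` rounds the row width up to the next power of two between 32 and 256. A 100-wide row now gets a 128-thread block instead of the previously hard-coded 256, so narrow inputs stop launching mostly idle threads; anything wider than 128 still caps at 256.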
+
+#define FIXED_BLOCK_DIM_BASE(dim, ...) \
+  case (dim): {                        \
+    constexpr auto kBlockDim = (dim);  \
+    __VA_ARGS__;                       \
+  } break
+
+#define FIXED_BLOCK_DIM(...)                \
+  FIXED_BLOCK_DIM_BASE(256, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(128, ##__VA_ARGS__); \
+  FIXED_BLOCK_DIM_BASE(64, ##__VA_ARGS__);  \
+  FIXED_BLOCK_DIM_BASE(32, ##__VA_ARGS__)
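These macros are variadic because a kernel launch's template-argument list contains commas the preprocessor would otherwise split into separate macro arguments; passing the statement through `__VA_ARGS__` re-joins the pieces with their commas, keeping the launch intact. Inside a `switch`, `FIXED_BLOCK_DIM(LAUNCH)` expands to roughly the following, written out by hand here for illustration (`LAUNCH` stands for the whole launch statement):

case (256): { constexpr auto kBlockDim = (256); LAUNCH; } break;
case (128): { constexpr auto kBlockDim = (128); LAUNCH; } break;
case (64):  { constexpr auto kBlockDim = (64);  LAUNCH; } break;
case (32):  { constexpr auto kBlockDim = (32);  LAUNCH; } break;

Each case binds the runtime choice to the compile-time constant `kBlockDim`, which is what lets `KeMatrixTopK` take the block size as a template parameter.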
+
 template <typename T>
 class TopkOpCUDAKernel : public framework::OpKernel<T> {
  public:
...
@@ -298,30 +327,38 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
     size_t k = static_cast<int>(ctx.Attr<int>("k"));
     const T* input_data = input->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     // FIXME(typhoonzero): data is always converted to type T?
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-    size_t input_height = input->dims()[0];
-    size_t input_width = input->dims()[1];
+    framework::DDim inputdims = input->dims();
+    const size_t input_height = framework::product(
+        framework::slice_ddim(inputdims, 0, inputdims.size() - 1));
+    const size_t input_width = inputdims[inputdims.size() - 1];
     if (k > input_width) k = input_width;
     // NOTE: pass lds and dim same to input width.
     // NOTE: old matrix implementation of stride is different to eigen.
     // TODO(typhoonzero): refine this kernel.
-    dim3 threads(256, 1);
-    dim3 grid(input_height, 1);
-    KeMatrixTopK<T, 5, 256><<<
-        grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(
-        output_data, output->dims()[1], indices_data, input_data, input_width,
-        input_width, static_cast<int>(k));
+    const int kMaxHeight = 2048;
+    int gridx = input_height < kMaxHeight ? input_height : kMaxHeight;
+    auto& dev_ctx = ctx.cuda_device_context();
+    switch (GetDesiredBlockDim(input_width)) {
+      FIXED_BLOCK_DIM(
+          KeMatrixTopK<T, 5,
+                       kBlockDim><<<gridx, kBlockDim, 0, dev_ctx.stream()>>>(
+              output_data, k, indices_data, input_data, input_width,
+              input_width, static_cast<int>(k), gridx, input_height));
+      default:
+        PADDLE_THROW("Error");
+    }
   }
 };
+#undef FIXED_BLOCK_DIM_BASE
+#undef FIXED_BLOCK_DIM

 }  // namespace operators
 }  // namespace paddle
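Note: the launch now differs in three ways. The grid is capped at `kMaxHeight = 2048` blocks, with the kernel's grid-stride loop covering any rows beyond that. The block width comes from `GetDesiredBlockDim(input_width)` and is lowered to the compile-time `kBlockDim` through the switch, so each instantiation is compiled for its exact block size. And the output stride passed to the kernel is now `k` (the last dimension of `Out`) rather than `output->dims()[1]`, which stays correct now that inputs with more than two dimensions are flattened. The `default: PADDLE_THROW("Error")` branch should be unreachable, since `GetDesiredBlockDim` only returns the four handled values; a more descriptive message would still help if the dispatch ever drifts.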
...
...
@@ -34,7 +34,6 @@ class TopkKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
     // Get the top k elements of each row of input tensor
-    // FIXME: only deal with matrix(2d tensor).
     auto* input = ctx.Input<Tensor>("X");
     auto* output = ctx.Output<Tensor>("Out");
     auto* indices = ctx.Output<Tensor>("Indices");
@@ -44,8 +43,6 @@ class TopkKernel : public framework::OpKernel<T> {
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
     int64_t* indices_data = indices->mutable_data<int64_t>(ctx.GetPlace());
-    auto eg_input = EigenMatrix<T>::From(*input);
-
     // reshape input to a flattened matrix (like flat_inner_dims)
     framework::DDim inputdims = input->dims();
     const size_t row = framework::product(
@@ -53,7 +50,7 @@ class TopkKernel : public framework::OpKernel<T> {
     const size_t col = inputdims[inputdims.size() - 1];
     Eigen::DSizes<int, 2> flat2dims(row, col);
     // NOTE: eigen shape doesn't affect paddle tensor.
-    eg_input.reshape(flat2dims);
+    auto eg_input = EigenMatrix<T>::Reshape(*input, inputdims.size() - 1);
 #ifdef PADDLE_WITH_MKLML
 #pragma omp parallel for
...
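The CPU path gets the same N-D generalization. The old code built `eg_input` with `EigenMatrix<T>::From(*input)` and then called `eg_input.reshape(flat2dims);` as a bare statement; Eigen's `reshape` returns a new expression rather than mutating its operand, so that line was very likely a no-op. The fix constructs the flattened view directly with `EigenMatrix<T>::Reshape(*input, inputdims.size() - 1)`, collapsing all leading dimensions into rows. A self-contained sketch of that flattening arithmetic (`FlattenToMatrix` is illustrative, not Paddle API):

#include <cstddef>
#include <cstdint>
#include <vector>

// Collapse all leading dimensions into rows; keep the last as columns.
// E.g. a {2, 3, 4} input becomes a 6 x 4 matrix, and top-k runs per row.
static void FlattenToMatrix(const std::vector<int64_t>& dims,
                            size_t* rows, size_t* cols) {
  size_t r = 1;
  for (size_t i = 0; i + 1 < dims.size(); ++i) {
    r *= static_cast<size_t>(dims[i]);
  }
  *rows = r;
  *cols = static_cast<size_t>(dims.back());
}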