提交 92cfa2be 编写于 作者: Y Yibing Liu

Avoid using dynamic arrays in the CUDA kernel

上级 28a0ac53
...@@ -38,8 +38,8 @@ class ArgsortOp : public framework::OperatorWithKernel { ...@@ -38,8 +38,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
"dimension %d.", "dimension %d.",
axis, num_dims); axis, num_dims);
PADDLE_ENFORCE(in_dims.size() + axis >= 0, PADDLE_ENFORCE(in_dims.size() + axis >= 0,
"Attr(axis) %d of ArgsortOp plus the number of Input(X)'s " "Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) "
"dimensions %d must be nonnegative.", "must be nonnegative.",
axis, in_dims.size()); axis, in_dims.size());
ctx->SetOutputDim("Out", in_dims); ctx->SetOutputDim("Out", in_dims);
......
...@@ -31,8 +31,9 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size, ...@@ -31,8 +31,9 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
int64_t* med_ids) { int64_t* med_ids) {
int64_t index = threadIdx.x + blockDim.x * blockIdx.x; int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) { if (index < n) {
int64_t* shape_out_axis = new int64_t[dims_size - 1]; const int max_rank = 9; // Max rank of a tensor allow in Fluid
int64_t* dims_out_axis = new int64_t[dims_size - 1]; int64_t shape_out_axis[max_rank - 1] = {0};
int64_t dims_out_axis[max_rank - 1] = {0};
int64_t tmp = index; int64_t tmp = index;
int64_t pos_in_axis = 0; int64_t pos_in_axis = 0;
int64_t i = dims_size - 2; int64_t i = dims_size - 2;
...@@ -57,8 +58,6 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size, ...@@ -57,8 +58,6 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
int64_t traget_idx = group * dim_axis + pos_in_axis; int64_t traget_idx = group * dim_axis + pos_in_axis;
trg_idx[index] = traget_idx; trg_idx[index] = traget_idx;
med_ids[traget_idx] = pos_in_axis; med_ids[traget_idx] = pos_in_axis;
delete[] shape_out_axis;
delete[] dims_out_axis;
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册