提交 92cfa2be，作者：Yibing Liu

Avoid using dynamic array in cuda kernel

上级 28a0ac53
......@@ -38,8 +38,8 @@ class ArgsortOp : public framework::OperatorWithKernel {
"dimension %d.",
axis, num_dims);
PADDLE_ENFORCE(in_dims.size() + axis >= 0,
"Attr(axis) %d of ArgsortOp plus the number of Input(X)'s "
"dimensions %d must be nonnegative.",
"Attr(axis) %d of ArgsortOp plus the rank %d of Input(X) "
"must be nonnegative.",
axis, in_dims.size());
ctx->SetOutputDim("Out", in_dims);
......
......@@ -31,8 +31,9 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
int64_t* med_ids) {
int64_t index = threadIdx.x + blockDim.x * blockIdx.x;
if (index < n) {
int64_t* shape_out_axis = new int64_t[dims_size - 1];
int64_t* dims_out_axis = new int64_t[dims_size - 1];
const int max_rank = 9; // Max rank of a tensor allowed in Fluid
int64_t shape_out_axis[max_rank - 1] = {0};
int64_t dims_out_axis[max_rank - 1] = {0};
int64_t tmp = index;
int64_t pos_in_axis = 0;
int64_t i = dims_size - 2;
......@@ -57,8 +58,6 @@ __global__ void ComputeTargetIdx(const int64_t* in_dims, int dims_size,
int64_t traget_idx = group * dim_axis + pos_in_axis;
trg_idx[index] = traget_idx;
med_ids[traget_idx] = pos_in_axis;
delete[] shape_out_axis;
delete[] dims_out_axis;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册