From 03ccb9a461db7650fd1dc749f2f61a4df253bf31 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Thu, 15 Nov 2018 16:07:16 +0800 Subject: [PATCH] Optimize the stack operator --- paddle/fluid/operators/stack_op.h | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index d236c5b9437..f1692ae9563 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -147,16 +147,23 @@ class StackKernel : public framework::OpKernel { auto &dim = x[0]->dims(); for (auto i = 0; i < axis; ++i) pre *= dim[i]; for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; - int total_num = pre * n * post; - auto &dev_ctx = ctx.template device_context(); #ifdef __NVCC__ thrust::device_vector device_x_vec(x_datas); auto x_data_arr = device_x_vec.data().get(); #else auto x_data_arr = x_datas.data(); #endif - StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + size_t x_offset = 0; + size_t y_offset = 0; + for (int i = 0; i < pre; i++) { + for (int j = 0; j < n; j++) { + std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset, + post * sizeof(T)); + y_offset += post; + } + x_offset += post; + } #ifdef __NVCC__ // Wait() must be called because device_x_vec may be destructed before // kernel ends -- GitLab