From a906a361be831b9b425a9f197036fef506020857 Mon Sep 17 00:00:00 2001 From: Yihua Xu Date: Tue, 20 Nov 2018 08:30:27 +0800 Subject: [PATCH] Add the macro for NVCC (test=develop) --- paddle/fluid/operators/stack_op.h | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/stack_op.h b/paddle/fluid/operators/stack_op.h index f1692ae95..3d132e439 100644 --- a/paddle/fluid/operators/stack_op.h +++ b/paddle/fluid/operators/stack_op.h @@ -149,11 +149,20 @@ class StackKernel : public framework::OpKernel { for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; #ifdef __NVCC__ + int total_num = pre * n * post; + auto &dev_ctx = ctx.template device_context(); + thrust::device_vector device_x_vec(x_datas); auto x_data_arr = device_x_vec.data().get(); + + StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); + + // Wait() must be called because device_x_vec may be destructed before + // kernel ends + dev_ctx.Wait(); #else auto x_data_arr = x_datas.data(); -#endif + size_t x_offset = 0; size_t y_offset = 0; for (int i = 0; i < pre; i++) { @@ -164,10 +173,6 @@ class StackKernel : public framework::OpKernel { } x_offset += post; } -#ifdef __NVCC__ - // Wait() must be called because device_x_vec may be destructed before - // kernel ends - dev_ctx.Wait(); #endif } }; -- GitLab