未验证 提交 1d3e9bde 编写于 作者: T Tao Luo 提交者: GitHub

Merge pull request #14488 from yihuaxu/develop_7a64d48f_stack_opt

Optimize the stack operator
...@@ -147,20 +147,32 @@ class StackKernel : public framework::OpKernel<T> { ...@@ -147,20 +147,32 @@ class StackKernel : public framework::OpKernel<T> {
auto &dim = x[0]->dims(); auto &dim = x[0]->dims();
for (auto i = 0; i < axis; ++i) pre *= dim[i]; for (auto i = 0; i < axis; ++i) pre *= dim[i];
for (auto i = axis; i < dim.size(); ++i) post *= dim[i]; for (auto i = axis; i < dim.size(); ++i) post *= dim[i];
int total_num = pre * n * post;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
#ifdef __NVCC__ #ifdef __NVCC__
int total_num = pre * n * post;
auto &dev_ctx = ctx.template device_context<DeviceContext>();
thrust::device_vector<const T *> device_x_vec(x_datas); thrust::device_vector<const T *> device_x_vec(x_datas);
auto x_data_arr = device_x_vec.data().get(); auto x_data_arr = device_x_vec.data().get();
#else
auto x_data_arr = x_datas.data();
#endif
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post); StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
#ifdef __NVCC__
// Wait() must be called because device_x_vec may be destructed before // Wait() must be called because device_x_vec may be destructed before
// kernel ends // kernel ends
dev_ctx.Wait(); dev_ctx.Wait();
#else
auto x_data_arr = x_datas.data();
size_t x_offset = 0;
size_t y_offset = 0;
for (int i = 0; i < pre; i++) {
for (int j = 0; j < n; j++) {
std::memcpy(y_data + y_offset, x_data_arr[j] + x_offset,
post * sizeof(T));
y_offset += post;
}
x_offset += post;
}
#endif #endif
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册