未验证 提交 c1488b17 编写于 作者: Y Yibing Liu 提交者: GitHub

Merge pull request #12940 from sneaxiy/stack_op

Speedup stack_op
...@@ -150,12 +150,7 @@ class StackKernel : public framework::OpKernel<T> { ...@@ -150,12 +150,7 @@ class StackKernel : public framework::OpKernel<T> {
int total_num = pre * n * post; int total_num = pre * n * post;
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto &dev_ctx = ctx.template device_context<DeviceContext>();
constexpr auto kMaxThreshold = 16;
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
n > kMaxThreshold) {
#ifdef __NVCC__ #ifdef __NVCC__
VLOG(10) << "Stack more than " << kMaxThreshold
<< " tensors on GPU may be slow.";
thrust::device_vector<const T *> device_x_vec(x_datas); thrust::device_vector<const T *> device_x_vec(x_datas);
auto x_data_arr = device_x_vec.data().get(); auto x_data_arr = device_x_vec.data().get();
#else #else
...@@ -168,14 +163,6 @@ class StackKernel : public framework::OpKernel<T> { ...@@ -168,14 +163,6 @@ class StackKernel : public framework::OpKernel<T> {
dev_ctx.Wait(); dev_ctx.Wait();
#endif #endif
} }
#ifdef __NVCC__
else { // NOLINT
framework::Array<const T *, kMaxThreshold> x_data_arr;
for (int i = 0; i < n; ++i) x_data_arr[i] = x_datas[i];
StackFunctorForRange(dev_ctx, x_data_arr, y_data, total_num, n, post);
}
#endif
}
}; };
class StackOpGrad : public framework::OperatorWithKernel { class StackOpGrad : public framework::OperatorWithKernel {
...@@ -244,34 +231,19 @@ class StackGradKernel : public framework::OpKernel<T> { ...@@ -244,34 +231,19 @@ class StackGradKernel : public framework::OpKernel<T> {
int post = total_num / (n * pre); int post = total_num / (n * pre);
auto &dev_ctx = ctx.template device_context<DeviceContext>(); auto &dev_ctx = ctx.template device_context<DeviceContext>();
constexpr auto kMaxThreshold = 16;
if (std::is_same<DeviceContext, platform::CPUDeviceContext>::value ||
n > kMaxThreshold) {
#ifdef __NVCC__ #ifdef __NVCC__
VLOG(10) << "Stack more than " << kMaxThreshold
<< " tensors on GPU may be slow.";
thrust::device_vector<T *> device_dx_vec(dx_datas); thrust::device_vector<T *> device_dx_vec(dx_datas);
auto dx_data_arr = device_dx_vec.data().get(); auto dx_data_arr = device_dx_vec.data().get();
#else #else
auto dx_data_arr = dx_datas.data(); auto dx_data_arr = dx_datas.data();
#endif #endif
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n, post);
post);
#ifdef __NVCC__ #ifdef __NVCC__
// Wait() must be called because device_dx_vec may be destructed before // Wait() must be called because device_dx_vec may be destructed before
// kernel ends // kernel ends
dev_ctx.Wait(); dev_ctx.Wait();
#endif #endif
} }
#ifdef __NVCC__
else { // NOLINT
framework::Array<T *, kMaxThreshold> dx_data_arr;
for (int i = 0; i < n; ++i) dx_data_arr[i] = dx_datas[i];
StackGradFunctorForRange(dev_ctx, dx_data_arr, dy_data, total_num, n,
post);
}
#endif
}
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册