Commit 272f3e6d authored by T typhoonzero

refine get cuda context

Parent ff4c20e0
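This change adds an ExecutionContext::cuda_device_context() accessor and updates the GPU op kernels to call it instead of reinterpret_cast-ing the generic device context at every call site. Below is a minimal, self-contained sketch of the before/after pattern; the types are simplified stand-ins rather than the real Paddle headers, and cudaStream_t is faked as a void* so the sketch builds without CUDA.

// Minimal sketch of the refactoring pattern in this commit, using simplified
// stand-in types (not the actual Paddle classes).
#include <cassert>
#include <iostream>

using cudaStream_t = void*;  // stand-in so the sketch builds without CUDA

struct DeviceContext {
  virtual ~DeviceContext() = default;
};

struct CUDADeviceContext : public DeviceContext {
  cudaStream_t stream() const { return stream_; }
  cudaStream_t stream_ = nullptr;
};

class ExecutionContext {
 public:
  explicit ExecutionContext(const DeviceContext& dev_ctx)
      : device_context_(dev_ctx) {}

  const DeviceContext& device_context() const { return device_context_; }

  // The helper this commit introduces: the place check and the cast live in
  // one spot, so kernels can simply write ctx.cuda_device_context().stream().
  inline const CUDADeviceContext& cuda_device_context() const {
    // The real code enforces platform::is_gpu_place(...) here.
    assert(dynamic_cast<const CUDADeviceContext*>(&device_context_) != nullptr);
    return *reinterpret_cast<const CUDADeviceContext*>(&device_context_);
  }

 private:
  const DeviceContext& device_context_;
};

int main() {
  CUDADeviceContext cuda_ctx;
  ExecutionContext ctx(cuda_ctx);

  // Before: every GPU kernel repeated this cast just to read the stream.
  auto old_style =
      reinterpret_cast<const CUDADeviceContext&>(ctx.device_context()).stream();

  // After: call sites use the shared accessor.
  auto new_style = ctx.cuda_device_context().stream();

  std::cout << (old_style == new_style) << "\n";  // prints 1
  return 0;
}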
@@ -298,11 +298,10 @@ class ExecutionContext {
   }
 #ifdef PADDLE_WITH_CUDA
-  const platform::CUDADeviceContext& cuda_device_context() const {
+  const inline platform::CUDADeviceContext& cuda_device_context() const {
     PADDLE_ENFORCE(platform::is_gpu_place(device_context_.GetPlace()));
-    auto cuda_ctx =
-        reinterpret_cast<const platform::CUDADeviceContext*>(&device_context_);
-    return *cuda_ctx;
+    return *reinterpret_cast<const platform::CUDADeviceContext*>(
+        &device_context_);
   }
 #endif
...
@@ -72,11 +72,8 @@ class AccuracyOpCUDAKernel : public framework::OpKernel<T> {
     }
     AccuracyCudaKernel<PADDLE_CUDA_NUM_THREADS><<<
-        1, PADDLE_CUDA_NUM_THREADS, 0,
-        reinterpret_cast<const platform::CUDADeviceContext&>(
-            ctx.device_context())
-            .stream()>>>(num_samples, infer_width, indices_data, label_data,
-                         accuracy_data);
+        1, PADDLE_CUDA_NUM_THREADS, 0, ctx.cuda_device_context().stream()>>>(
+        num_samples, infer_width, indices_data, label_data, accuracy_data);
   }
 };
...
@@ -27,7 +27,6 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
 using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
 using DataLayout = platform::DataLayout;
-using CUDADeviceContext = platform::CUDADeviceContext;
 static constexpr size_t kConvCudnnWorkspaceLimitBytes = 1024 * 1024 * 1024;
...
@@ -27,7 +27,6 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
 using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
 using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
 using DataLayout = platform::DataLayout;
-using CUDADeviceContext = platform::CUDADeviceContext;
 static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
...
@@ -130,9 +130,7 @@ class ConvShiftKernel<platform::GPUPlace, T> : public framework::OpKernel<T> {
     dim3 grid_dim(num_x_blocks, batch_size);
-    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
-                      context.device_context())
-                      .stream();
+    auto stream = context.cuda_device_context().stream();
     conv_shift_forward<T><<<grid_dim, x_per_block, mem_per_block, stream>>>(
         x_data, y_data, out_data, x_width, y_width, y_half_width, batch_size);
@@ -159,9 +157,7 @@ class ConvShiftGradKernel<platform::GPUPlace, T>
     int y_width = Y->dims()[1];
     int y_half_width = (y_width - 1) / 2;
-    auto stream = reinterpret_cast<const platform::CUDADeviceContext &>(
-                      context.device_context())
-                      .stream();
+    auto stream = context.cuda_device_context().stream();
     const int x_per_block = 256;
     int num_x_blocks = div_up(x_width, x_per_block);
...
@@ -82,24 +82,19 @@ class CrossEntropyGradientOpCUDAKernel : public framework::OpKernel<T> {
     int block = 512;
     int grid = (batch_size * class_num + block - 1) / block;
+    auto stream = ctx.cuda_device_context().stream();
     if (ctx.Attr<bool>("soft_label")) {
       auto* label_data = label->data<T>();
-      SoftCrossEntropyGradientKernel<T><<<
-          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(dx_data, dy_data, x_data, label_data,
-                                           batch_size, class_num);
+      SoftCrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+          dx_data, dy_data, x_data, label_data, batch_size, class_num);
     } else {
       math::SetConstant<platform::GPUPlace, T> functor;
       functor(ctx.device_context(), dx, 0);
       auto* label_data = label->data<int64_t>();
       grid = (batch_size + block - 1) / block;
-      CrossEntropyGradientKernel<T><<<
-          grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                              ctx.device_context())
-                              .stream()>>>(dx_data, dy_data, x_data, label_data,
-                                           batch_size, class_num);
+      CrossEntropyGradientKernel<T><<<grid, block, 0, stream>>>(
+          dx_data, dy_data, x_data, label_data, batch_size, class_num);
     }
   }
 };
...
@@ -74,10 +74,9 @@ class LookupTableCUDAKernel : public framework::OpKernel<T> {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTable<T, 128, 8, 8><<<
-        grids, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
-                               context.device_context())
-                               .stream()>>>(output, table, ids, N, K, D);
+    LookupTable<T, 128, 8,
+                8><<<grids, threads, 0, context.device_context().stream()>>>(
+        output, table, ids, N, K, D);
   }
 };
@@ -95,9 +94,7 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     auto* ids_data = ids->data<int64_t>();
     auto ids_dim = ids->dims();
-    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                      context.device_context())
-                      .stream();
+    auto stream = context.cuda_device_context().stream();
     // copy GPU memory to CPU pinned memory
     framework::Vector<int64_t> new_rows;
     new_rows.resize(ids_dim[0]);
@@ -136,11 +133,10 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
     dim3 threads(128, 8);
     dim3 grids(8, 1);
-    LookupTableGrad<T, 128, 8,
-                    8><<<grids, threads, 0,
-                         reinterpret_cast<const platform::CUDADeviceContext&>(
-                             context.device_context())
-                             .stream()>>>(d_table, d_output, ids, N, K, D);
+    LookupTableGrad<
+        T, 128, 8,
+        8><<<grids, threads, 0, context.device_context().stream()>>>(
+        d_table, d_output, ids, N, K, D);
     }
   }
 };
...
@@ -35,9 +35,7 @@ class MultiplexGPUKernel : public framework::OpKernel<T> {
     Tensor index_t_cpu;
     index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
     auto* index = index_t_cpu.data<int32_t>();
-    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                      ctx.device_context())
-                      .stream();
+    auto stream = ctx.cuda_device_context().stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
       int32_t k = index[i];
@@ -73,9 +71,7 @@ class MultiplexGradGPUKernel : public framework::OpKernel<T> {
     index_t_cpu.CopyFrom(*ids, platform::CPUPlace(), ctx.device_context());
     auto* index = index_t_cpu.data<int32_t>();
-    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                      ctx.device_context())
-                      .stream();
+    auto stream = ctx.device_context().stream();
     Place place = boost::get<Place>(ctx.GetPlace());
     for (auto i = 0; i < rows; i++) {
       size_t k = static_cast<size_t>(index[i]);
...
@@ -64,9 +64,7 @@ class NCCLAllReduceKernel : public framework::OpKernel<T> {
     auto* comm = ctx.Input<Communicator>("Communicator");
-    auto stream = reinterpret_cast<const platform::CUDADeviceContext&>(
-                      ctx.device_context())
-                      .stream();
+    auto stream = ctx.cuda_device_context().stream();
     // device id
     int gpu_id = boost::get<platform::GPUPlace>(ctx.GetPlace()).GetDeviceId();
...