提交 14fb15b6 编写于 作者: W wanghaoshuang

Remove const cast for device context

上级 a3c3b786
...@@ -54,12 +54,11 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel { ...@@ -54,12 +54,11 @@ class ClipGradientOpCUDAKernel : public framework::OpKernel {
int D = d_x->dims()[1]; int D = d_x->dims()[1];
int block = 512; int block = 512;
int grid = (N * D + block - 1) / block; int grid = (N * D + block - 1) / block;
auto* device_context = ClipGradientKernel<T><<<
const_cast<platform::DeviceContext*>(context.device_context_); grid, block, 0, reinterpret_cast<const platform::CUDADeviceContext&>(
ClipGradientKernel< context.device_context())
T><<<grid, block, 0, .stream()>>>(count, min, max, x_data, d_out_data,
reinterpret_cast<platform::CUDADeviceContext*>(device_context) d_x_data);
->stream()>>>(count, min, max, x_data, d_out_data, d_x_data);
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册