提交 8ebfc153 编写于 作者: K Kexin Zhao

update

上级 bfbc25bd
...@@ -28,6 +28,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; ...@@ -28,6 +28,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedFilterDescriptor = platform::ScopedFilterDescriptor; using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor; using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
static_cast<size_t>(1024) * 1024 * 1024; static_cast<size_t>(1024) * 1024 * 1024;
...@@ -134,8 +136,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> { ...@@ -134,8 +136,7 @@ class CUDNNConvOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace()); platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv forward --------------------- // ------------------- cudnn conv forward ---------------------
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f, ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
beta = 0.0f;
for (int i = 0; i < groups; i++) { for (int i = 0; i < groups; i++) {
PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward( PADDLE_ENFORCE(platform::dynload::cudnnConvolutionForward(
handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in, handle, &alpha, cudnn_input_desc, input_data + i * group_offset_in,
...@@ -282,8 +283,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> { ...@@ -282,8 +283,7 @@ class CUDNNConvGradOpKernel : public framework::OpKernel<T> {
platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace()); platform::CUDAPlace gpu = boost::get<platform::CUDAPlace>(ctx.GetPlace());
cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes); cudnn_workspace = paddle::memory::Alloc(gpu, workspace_size_in_bytes);
// ------------------- cudnn conv backward data --------------------- // ------------------- cudnn conv backward data ---------------------
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f, ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
beta = 0.0f;
if (input_grad) { if (input_grad) {
T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T* input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
......
...@@ -24,6 +24,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor; ...@@ -24,6 +24,8 @@ using ScopedTensorDescriptor = platform::ScopedTensorDescriptor;
using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor; using ScopedPoolingDescriptor = platform::ScopedPoolingDescriptor;
using DataLayout = platform::DataLayout; using DataLayout = platform::DataLayout;
using PoolingMode = platform::PoolingMode; using PoolingMode = platform::PoolingMode;
template <typename T>
using ScalingParamType = typename platform::CudnnDataType<T>::ScalingParamType;
template <typename T> template <typename T>
class PoolCUDNNOpKernel : public framework::OpKernel<T> { class PoolCUDNNOpKernel : public framework::OpKernel<T> {
...@@ -78,9 +80,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> { ...@@ -78,9 +80,7 @@ class PoolCUDNNOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm --------------------- // ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f, ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
beta = 0.0f;
PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward( PADDLE_ENFORCE(platform::dynload::cudnnPoolingForward(
handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta, handle, cudnn_pool_desc, &alpha, cudnn_input_desc, input_data, &beta,
cudnn_output_desc, output_data)); cudnn_output_desc, output_data));
...@@ -145,9 +145,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> { ...@@ -145,9 +145,7 @@ class PoolCUDNNGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn pool algorithm --------------------- // ------------------- cudnn pool algorithm ---------------------
auto handle = ctx.cuda_device_context().cudnn_handle(); auto handle = ctx.cuda_device_context().cudnn_handle();
typename platform::CudnnDataType<T>::ScalingParamType alpha = 1.0f, ScalingParamType<T> alpha = 1.0f, beta = 0.0f;
beta = 0.0f;
if (input_grad) { if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
// Because beta is zero, it is unnecessary to reset input_grad. // Because beta is zero, it is unnecessary to reset input_grad.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册