From df48b43b91a67ee70df76630ebb560d2cf1d105a Mon Sep 17 00:00:00 2001 From: chengduoZH Date: Fri, 27 Oct 2017 10:36:35 +0800 Subject: [PATCH] fix clear zero method and remove useless code --- paddle/operators/pool_cudnn_op.cu | 18 ++++-------------- .../v2/framework/tests/test_pool_max_op.py | 2 +- 2 files changed, 5 insertions(+), 15 deletions(-) diff --git a/paddle/operators/pool_cudnn_op.cu b/paddle/operators/pool_cudnn_op.cu index f9366eb7544..2db4837c8cb 100644 --- a/paddle/operators/pool_cudnn_op.cu +++ b/paddle/operators/pool_cudnn_op.cu @@ -117,8 +117,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { // ------------------- cudnn descriptors --------------------- ScopedTensorDescriptor input_desc; ScopedTensorDescriptor output_desc; - ScopedTensorDescriptor input_grad_desc; - ScopedTensorDescriptor output_grad_desc; ScopedPoolingDescriptor pool_desc; DataLayout layout = DataLayout::kNCHW; @@ -126,9 +124,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { input_desc.descriptor(layout, Dims2VectorPool(input->dims())); cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor(layout, Dims2VectorPool(output->dims())); - cudnnTensorDescriptor_t cudnn_output_grad_desc = - output_grad_desc.descriptor(layout, - Dims2VectorPool(output_grad->dims())); PoolingMode pooling_mode; if (pooling_type == "max") { @@ -146,18 +141,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel { if (input_grad) { T *input_grad_data = input_grad->mutable_data(ctx.GetPlace()); - auto temp = framework::EigenVector::Flatten(*input_grad); - temp.device(ctx.GetEigenDevice()) = - temp.constant(static_cast(0)); - - cudnnTensorDescriptor_t cudnn_input_grad_desc = - input_grad_desc.descriptor(layout, - Dims2VectorPool(input_grad->dims())); + math::SetConstant set_zero; + set_zero(ctx.device_context(), input_grad, static_cast(0)); PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward( handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data, - cudnn_output_grad_desc, output_grad_data, cudnn_input_desc, - input_data, &beta, cudnn_input_grad_desc, input_grad_data)); + cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data, + &beta, cudnn_input_desc, input_grad_data)); } } }; diff --git a/python/paddle/v2/framework/tests/test_pool_max_op.py b/python/paddle/v2/framework/tests/test_pool_max_op.py index b78f9bba05c..f0f8aa6089c 100644 --- a/python/paddle/v2/framework/tests/test_pool_max_op.py +++ b/python/paddle/v2/framework/tests/test_pool_max_op.py @@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest): 'strides': self.strides, 'paddings': self.paddings, 'ksize': self.ksize, - 'global_pooling': self.global_pool, + 'globalPooling': self.global_pool, } self.inputs = {'X': input} -- GitLab