提交 df48b43b 编写于 作者: C chengduoZH

fix clear zero method and remove useless code

上级 e5c167dc
...@@ -117,8 +117,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -117,8 +117,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn descriptors --------------------- // ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc; ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc; ScopedTensorDescriptor output_desc;
ScopedTensorDescriptor input_grad_desc;
ScopedTensorDescriptor output_grad_desc;
ScopedPoolingDescriptor pool_desc; ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW; DataLayout layout = DataLayout::kNCHW;
...@@ -126,9 +124,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -126,9 +124,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims())); input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc = cudnnTensorDescriptor_t cudnn_output_desc =
output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims())); output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims()));
cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(layout,
Dims2VectorPool(output_grad->dims()));
PoolingMode pooling_mode; PoolingMode pooling_mode;
if (pooling_type == "max") { if (pooling_type == "max") {
...@@ -146,18 +141,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> { ...@@ -146,18 +141,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
if (input_grad) { if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace()); T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*input_grad); math::SetConstant<paddle::platform::GPUPlace, T> set_zero;
temp.device(ctx.GetEigenDevice<paddle::platform::GPUPlace>()) = set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
temp.constant(static_cast<T>(0));
cudnnTensorDescriptor_t cudnn_input_grad_desc =
input_grad_desc.descriptor<T>(layout,
Dims2VectorPool(input_grad->dims()));
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward( PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data, handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
cudnn_output_grad_desc, output_grad_data, cudnn_input_desc, cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
input_data, &beta, cudnn_input_grad_desc, input_grad_data)); &beta, cudnn_input_desc, input_grad_data));
} }
} }
}; };
......
...@@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest): ...@@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
'strides': self.strides, 'strides': self.strides,
'paddings': self.paddings, 'paddings': self.paddings,
'ksize': self.ksize, 'ksize': self.ksize,
'global_pooling': self.global_pool, 'globalPooling': self.global_pool,
} }
self.inputs = {'X': input} self.inputs = {'X': input}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册