提交 df48b43b 编写于 作者: C chengduoZH

fix clear zero method and remove useless code

上级 e5c167dc
......@@ -117,8 +117,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
// ------------------- cudnn descriptors ---------------------
ScopedTensorDescriptor input_desc;
ScopedTensorDescriptor output_desc;
ScopedTensorDescriptor input_grad_desc;
ScopedTensorDescriptor output_grad_desc;
ScopedPoolingDescriptor pool_desc;
DataLayout layout = DataLayout::kNCHW;
......@@ -126,9 +124,6 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
input_desc.descriptor<T>(layout, Dims2VectorPool(input->dims()));
cudnnTensorDescriptor_t cudnn_output_desc =
output_desc.descriptor<T>(layout, Dims2VectorPool(output->dims()));
cudnnTensorDescriptor_t cudnn_output_grad_desc =
output_grad_desc.descriptor<T>(layout,
Dims2VectorPool(output_grad->dims()));
PoolingMode pooling_mode;
if (pooling_type == "max") {
......@@ -146,18 +141,13 @@ class PoolCudnnGradOpKernel : public framework::OpKernel<T> {
if (input_grad) {
T *input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*input_grad);
temp.device(ctx.GetEigenDevice<paddle::platform::GPUPlace>()) =
temp.constant(static_cast<T>(0));
cudnnTensorDescriptor_t cudnn_input_grad_desc =
input_grad_desc.descriptor<T>(layout,
Dims2VectorPool(input_grad->dims()));
math::SetConstant<paddle::platform::GPUPlace, T> set_zero;
set_zero(ctx.device_context(), input_grad, static_cast<T>(0));
PADDLE_ENFORCE(platform::dynload::cudnnPoolingBackward(
handle, cudnn_pool_desc, &alpha, cudnn_output_desc, output_data,
cudnn_output_grad_desc, output_grad_data, cudnn_input_desc,
input_data, &beta, cudnn_input_grad_desc, input_grad_data));
cudnn_output_desc, output_grad_data, cudnn_input_desc, input_data,
&beta, cudnn_input_desc, input_grad_data));
}
}
};
......
......@@ -86,7 +86,7 @@ class TestMaxPoolWithIndex_Op(OpTest):
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'global_pooling': self.global_pool,
'globalPooling': self.global_pool,
}
self.inputs = {'X': input}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册