From bb546cf13e6076b28d748b526a4486021b0d2b84 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 21 Sep 2017 13:35:46 +0800 Subject: [PATCH] Bug fix. --- paddle/operators/gemm_conv2d_op.h | 22 +++++++++------------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h index 72de0a5cf3e..f9215b46fcd 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/gemm_conv2d_op.h @@ -75,8 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel { framework::DDim output_matrix_shape = {output_channels, output_height * output_width}; - auto device_context = context.device_context(); - // convolution operator: im2col + gemm int in_step = input_channels / groups; int out_step = output_channels / groups; @@ -87,13 +85,13 @@ class GemmConv2DKernel : public framework::OpKernel { // im2col Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], - device_context); + context.device_context()); // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(device_context, filter_slice, false, col_matrix, - false, T(1.0), &out_slice, T(0.0)); + math::matmul(context.device_context(), filter_slice, false, + col_matrix, false, T(1.0), &out_slice, T(0.0)); } } } @@ -159,8 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel { filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); - auto device_context = context.device_context(); - // convolution backward input operator: gemm + col2im // convolution backward weight operator: im2col + gemm int in_step = input_channels / groups; @@ -182,7 +178,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(device_context, filter_slice, true, + math::matmul(context.device_context(), filter_slice, true, out_grad_slice, false, T(1.0), &col_matrix, T(0.0)); @@ -190,7 +186,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel { Tensor in_grad_slice = in_grad_batch.Slice(g * in_step, (g + 1) * in_step); col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], - paddings[1], device_context); + paddings[1], context.device_context()); } } } @@ -212,14 +208,14 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); im2col(in_slice, col, strides[0], strides[1], paddings[0], - paddings[1], device_context); + paddings[1], context.device_context()); // gemm Tensor filter_grad_slice = filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(device_context, out_grad_slice, false, - col_matrix, true, T(1.0), &filter_grad_slice, - T(1.0)); + math::matmul(context.device_context(), out_grad_slice, + false, col_matrix, true, T(1.0), + &filter_grad_slice, T(1.0)); } } } -- GitLab