diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h
index 72de0a5cf3e640625d39edf45af623f131699ad1..f9215b46fcd9f091f88865df721481698ac0c8e7 100644
--- a/paddle/operators/gemm_conv2d_op.h
+++ b/paddle/operators/gemm_conv2d_op.h
@@ -75,8 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel {
     framework::DDim output_matrix_shape = {output_channels,
                                            output_height * output_width};
 
-    auto device_context = context.device_context();
-
     // convolution operator: im2col + gemm
     int in_step = input_channels / groups;
     int out_step = output_channels / groups;
@@ -87,13 +85,13 @@ class GemmConv2DKernel : public framework::OpKernel {
         // im2col
         Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
         im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
-               device_context);
+               context.device_context());
 
         // gemm
         Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step);
         Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-        math::matmul(device_context, filter_slice, false, col_matrix,
-                     false, T(1.0), &out_slice, T(0.0));
+        math::matmul(context.device_context(), filter_slice, false,
+                     col_matrix, false, T(1.0), &out_slice, T(0.0));
      }
    }
  }
@@ -159,8 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
                                           filter.numel() / filter.dims()[0]};
     filter.Resize(filter_matrix_shape);
 
-    auto device_context = context.device_context();
-
     // convolution backward input operator: gemm + col2im
     // convolution backward weight operator: im2col + gemm
     int in_step = input_channels / groups;
@@ -182,7 +178,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
               out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
           Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul(device_context, filter_slice, true,
+          math::matmul(context.device_context(), filter_slice, true,
                        out_grad_slice, false, T(1.0), &col_matrix,
                        T(0.0));
 
@@ -190,7 +186,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
           Tensor in_grad_slice =
               in_grad_batch.Slice(g * in_step, (g + 1) * in_step);
           col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
-                 paddings[1], device_context);
+                 paddings[1], context.device_context());
        }
      }
    }
@@ -212,14 +208,14 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
               out_grad_batch.Slice(g * out_step, (g + 1) * out_step);
           Tensor in_slice = in_batch.Slice(g * in_step, (g + 1) * in_step);
           im2col(in_slice, col, strides[0], strides[1], paddings[0],
-                 paddings[1], device_context);
+                 paddings[1], context.device_context());
 
           // gemm
           Tensor filter_grad_slice =
               filter_grad_.Slice(g * out_step, (g + 1) * out_step);
-          math::matmul(device_context, out_grad_slice, false,
-                       col_matrix, true, T(1.0), &filter_grad_slice,
-                       T(1.0));
+          math::matmul(context.device_context(), out_grad_slice,
+                       false, col_matrix, true, T(1.0),
+                       &filter_grad_slice, T(1.0));
        }
      }
    }
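
Note: as a reference for the kernels touched above, below is a minimal standalone sketch of the im2col + GEMM formulation they implement, for a single image and a single group. This is not PaddlePaddle's im2col / math::matmul API; every name and size in it is illustrative only. The filter is viewed as an [out_channels, in_channels * kh * kw] matrix, the im2col buffer as an [in_channels * kh * kw, out_h * out_w] matrix, and their product is the [out_channels, out_h * out_w] output slice.

// Minimal im2col + GEMM convolution sketch (illustrative only, float, NCHW,
// one image, one group). Not the Paddle API used in the diff above.
#include <vector>

// Expand input [C, H, W] into a column buffer [C * kh * kw, out_h * out_w].
void im2col_ref(const std::vector<float>& in, int C, int H, int W, int kh,
                int kw, int stride_h, int stride_w, int pad_h, int pad_w,
                int out_h, int out_w, std::vector<float>& col) {
  for (int c = 0; c < C; ++c)
    for (int i = 0; i < kh; ++i)
      for (int j = 0; j < kw; ++j)
        for (int oh = 0; oh < out_h; ++oh)
          for (int ow = 0; ow < out_w; ++ow) {
            int ih = oh * stride_h - pad_h + i;
            int iw = ow * stride_w - pad_w + j;
            float v = (ih >= 0 && ih < H && iw >= 0 && iw < W)
                          ? in[(c * H + ih) * W + iw]
                          : 0.0f;  // zero padding
            col[((c * kh + i) * kw + j) * (out_h * out_w) + oh * out_w + ow] = v;
          }
}

// Plain GEMM: out[M, N] = filter[M, K] * col[K, N], playing the role of
// math::matmul(..., filter_slice, false, col_matrix, false, 1, &out_slice, 0).
void gemm_ref(const std::vector<float>& a, const std::vector<float>& b, int M,
              int K, int N, std::vector<float>& out) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.0f;
      for (int k = 0; k < K; ++k) acc += a[m * K + k] * b[k * N + n];
      out[m * N + n] = acc;
    }
}

int main() {
  // Hypothetical sizes: 3x5x5 input, 8 filters of 3x3, stride 1, pad 1.
  int C = 3, H = 5, W = 5, kh = 3, kw = 3, stride = 1, pad = 1, M = 8;
  int out_h = (H + 2 * pad - kh) / stride + 1;  // 5
  int out_w = (W + 2 * pad - kw) / stride + 1;  // 5
  int K = C * kh * kw, N = out_h * out_w;
  std::vector<float> in(C * H * W, 1.0f), filter(M * K, 0.5f);
  std::vector<float> col(K * N), out(M * N);
  im2col_ref(in, C, H, W, kh, kw, stride, stride, pad, pad, out_h, out_w, col);
  gemm_ref(filter, col, M, K, N, out);  // out is the [M, out_h * out_w] map
  return 0;
}

The backward kernels in the diff reuse the same buffers in the other direction, as their comments state: filter^T * out_grad followed by col2im gives the input gradient, and out_grad * col^T accumulates the filter gradient.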