From 659f2f71ac62434485675ce6cc1403fe4409c589 Mon Sep 17 00:00:00 2001 From: hedaoyuan Date: Thu, 21 Sep 2017 13:29:42 +0800 Subject: [PATCH] Bug fix for get device_context. --- paddle/operators/gemm_conv2d_op.h | 21 +++++++++--------- paddle/operators/math/im2col.cc | 8 +++---- paddle/operators/math/im2col.cu | 32 ++++++++++++++-------------- paddle/operators/math/im2col.h | 4 ++-- paddle/operators/math/im2col_test.cc | 4 ++-- 5 files changed, 34 insertions(+), 35 deletions(-) diff --git a/paddle/operators/gemm_conv2d_op.h b/paddle/operators/gemm_conv2d_op.h index 08b7df1dfe..72de0a5cf3 100644 --- a/paddle/operators/gemm_conv2d_op.h +++ b/paddle/operators/gemm_conv2d_op.h @@ -75,8 +75,7 @@ class GemmConv2DKernel : public framework::OpKernel { framework::DDim output_matrix_shape = {output_channels, output_height * output_width}; - auto* device_context = - const_cast(context.device_context_); + auto device_context = context.device_context(); // convolution operator: im2col + gemm int in_step = input_channels / groups; @@ -93,8 +92,8 @@ class GemmConv2DKernel : public framework::OpKernel { // gemm Tensor out_slice = out_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, false, col_matrix, false, T(1.0), - &out_slice, T(0.0), device_context); + math::matmul(device_context, filter_slice, false, col_matrix, + false, T(1.0), &out_slice, T(0.0)); } } } @@ -160,8 +159,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel { filter.numel() / filter.dims()[0]}; filter.Resize(filter_matrix_shape); - auto* device_context = - const_cast(context.device_context_); + auto device_context = context.device_context(); // convolution backward input operator: gemm + col2im // convolution backward weight operator: im2col + gemm @@ -184,8 +182,9 @@ class GemmConvGrad2DKernel : public framework::OpKernel { out_grad_batch.Slice(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice(g * out_step, (g + 1) * out_step); - math::matmul(filter_slice, true, out_grad_slice, false, - T(1.0), &col_matrix, T(0.0), device_context); + math::matmul(device_context, filter_slice, true, + out_grad_slice, false, T(1.0), &col_matrix, + T(0.0)); // col2im Tensor in_grad_slice = @@ -218,9 +217,9 @@ class GemmConvGrad2DKernel : public framework::OpKernel { // gemm Tensor filter_grad_slice = filter_grad_.Slice(g * out_step, (g + 1) * out_step); - math::matmul(out_grad_slice, false, col_matrix, true, - T(1.0), &filter_grad_slice, T(1.0), - device_context); + math::matmul(device_context, out_grad_slice, false, + col_matrix, true, T(1.0), &filter_grad_slice, + T(1.0)); } } } diff --git a/paddle/operators/math/im2col.cc b/paddle/operators/math/im2col.cc index 5727c1cab1..36a07f7a31 100644 --- a/paddle/operators/math/im2col.cc +++ b/paddle/operators/math/im2col.cc @@ -29,7 +29,7 @@ class Im2ColFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + im2col<<(context) + .stream()>>>( im.data(), num_outputs, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width, col.data()); @@ -151,7 +151,7 @@ class Col2ImFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + col2im<<(context) + .stream()>>>( num_kernels, col.data(), input_height + 2 * padding_height, input_width + 2 * padding_width, input_channels, filter_height, filter_width, stride_height, stride_width, padding_height, @@ -237,7 +237,7 @@ class Im2ColFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + im2colOCF<<(context) + .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); @@ -320,7 +320,7 @@ class Col2ImFunctor<<< - grid, threads, 0, - reinterpret_cast(context)->stream()>>>( + col2imOCF<<(context) + .stream()>>>( im.data(), col.data(), input_channels, input_height, input_width, filter_height, filter_width, stride_height, stride_width, padding_height, padding_width, output_height, output_width); diff --git a/paddle/operators/math/im2col.h b/paddle/operators/math/im2col.h index 8958c5457c..9a119c6894 100644 --- a/paddle/operators/math/im2col.h +++ b/paddle/operators/math/im2col.h @@ -74,7 +74,7 @@ class Im2ColFunctor { public: void operator()(const framework::Tensor& im, framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context); + int padding_width, const platform::DeviceContext& context); }; template @@ -82,7 +82,7 @@ class Col2ImFunctor { public: void operator()(framework::Tensor& im, const framework::Tensor& col, int stride_height, int stride_width, int padding_height, - int padding_width, platform::DeviceContext* context); + int padding_width, const platform::DeviceContext& context); }; } // namespace math diff --git a/paddle/operators/math/im2col_test.cc b/paddle/operators/math/im2col_test.cc index 4f380388b1..e0943c0379 100644 --- a/paddle/operators/math/im2col_test.cc +++ b/paddle/operators/math/im2col_test.cc @@ -78,8 +78,8 @@ void testIm2col() { PADDLE_THROW("no GPU support"); #endif // PADDLE_ONLY_CPU } - im2col(input, output_cfo, stride, stride, padding, padding, context); - im2col_ocf(input, output_ocf, stride, stride, padding, padding, context); + im2col(input, output_cfo, stride, stride, padding, padding, *context); + im2col_ocf(input, output_ocf, stride, stride, padding, padding, *context); float* out_cfo_ptr; if (paddle::platform::is_cpu_place(*place)) { -- GitLab