提交 659f2f71 编写于 作者: H hedaoyuan

Bug fix: get the device context via context.device_context() and pass it as a const reference.

上级 d827359c
...@@ -75,8 +75,7 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -75,8 +75,7 @@ class GemmConv2DKernel : public framework::OpKernel {
framework::DDim output_matrix_shape = {output_channels, framework::DDim output_matrix_shape = {output_channels,
output_height * output_width}; output_height * output_width};
auto* device_context = auto device_context = context.device_context();
const_cast<platform::DeviceContext*>(context.device_context_);
// convolution operator: im2col + gemm // convolution operator: im2col + gemm
int in_step = input_channels / groups; int in_step = input_channels / groups;
...@@ -93,8 +92,8 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -93,8 +92,8 @@ class GemmConv2DKernel : public framework::OpKernel {
// gemm // gemm
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(filter_slice, false, col_matrix, false, T(1.0), math::matmul<Place, T>(device_context, filter_slice, false, col_matrix,
&out_slice, T(0.0), device_context); false, T(1.0), &out_slice, T(0.0));
} }
} }
} }
...@@ -160,8 +159,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -160,8 +159,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
auto* device_context = auto device_context = context.device_context();
const_cast<platform::DeviceContext*>(context.device_context_);
// convolution backward input operator: gemm + col2im // convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm // convolution backward weight operator: im2col + gemm
...@@ -184,8 +182,9 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -184,8 +182,9 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step); out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice = Tensor filter_slice =
filter.Slice<T>(g * out_step, (g + 1) * out_step); filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(filter_slice, true, out_grad_slice, false, math::matmul<Place, T>(device_context, filter_slice, true,
T(1.0), &col_matrix, T(0.0), device_context); out_grad_slice, false, T(1.0), &col_matrix,
T(0.0));
// col2im // col2im
Tensor in_grad_slice = Tensor in_grad_slice =
...@@ -218,9 +217,9 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -218,9 +217,9 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step); filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(out_grad_slice, false, col_matrix, true, math::matmul<Place, T>(device_context, out_grad_slice, false,
T(1.0), &filter_grad_slice, T(1.0), col_matrix, true, T(1.0), &filter_grad_slice,
device_context); T(1.0));
} }
} }
} }
......
...@@ -29,7 +29,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -29,7 +29,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -81,7 +81,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -81,7 +81,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -139,7 +139,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -139,7 +139,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -199,7 +199,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -199,7 +199,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
......
...@@ -66,7 +66,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -66,7 +66,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -84,9 +84,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -84,9 +84,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int block_y = (blocks + 512 - 1) / 512; int block_y = (blocks + 512 - 1) / 512;
dim3 threads(1024, 1); dim3 threads(1024, 1);
dim3 grid(block_x, block_y); dim3 grid(block_x, block_y);
im2col<T><<< im2col<T><<<grid, threads, 0,
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(context)
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>( .stream()>>>(
im.data<T>(), num_outputs, input_height, input_width, filter_height, im.data<T>(), num_outputs, input_height, input_width, filter_height,
filter_width, stride_height, stride_width, padding_height, filter_width, stride_height, stride_width, padding_height,
padding_width, output_height, output_width, col.data<T>()); padding_width, output_height, output_width, col.data<T>());
...@@ -151,7 +151,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -151,7 +151,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
...@@ -174,9 +174,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -174,9 +174,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
// To avoid involving atomic operations, we will launch one kernel per // To avoid involving atomic operations, we will launch one kernel per
// bottom dimension, and then in the kernel add up the top dimensions. // bottom dimension, and then in the kernel add up the top dimensions.
col2im<T><<< col2im<T><<<grid, threads, 0,
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(context)
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>( .stream()>>>(
num_kernels, col.data<T>(), input_height + 2 * padding_height, num_kernels, col.data<T>(), input_height + 2 * padding_height,
input_width + 2 * padding_width, input_channels, filter_height, input_width + 2 * padding_width, input_channels, filter_height,
filter_width, stride_height, stride_width, padding_height, filter_width, stride_height, stride_width, padding_height,
...@@ -237,7 +237,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -237,7 +237,7 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -268,9 +268,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -268,9 +268,9 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
dim3 threads(block_dim_x, block_dim_y, dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels)); std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height); dim3 grid(output_width, output_height);
im2colOCF<T><<< im2colOCF<T><<<grid, threads, 0,
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(context)
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width);
...@@ -320,7 +320,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -320,7 +320,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context) { int padding_width, const platform::DeviceContext& context) {
PADDLE_ENFORCE(im.dims().size() == 3); PADDLE_ENFORCE(im.dims().size() == 3);
PADDLE_ENFORCE(col.dims().size() == 5); PADDLE_ENFORCE(col.dims().size() == 5);
int input_channels = im.dims()[0]; int input_channels = im.dims()[0];
...@@ -351,9 +351,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF, ...@@ -351,9 +351,9 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
dim3 threads(block_dim_x, block_dim_y, dim3 threads(block_dim_x, block_dim_y,
std::min(block_dim_z, input_channels)); std::min(block_dim_z, input_channels));
dim3 grid(output_width, output_height); dim3 grid(output_width, output_height);
col2imOCF<T><<< col2imOCF<T><<<grid, threads, 0,
grid, threads, 0, reinterpret_cast<const platform::CUDADeviceContext&>(context)
reinterpret_cast<platform::CUDADeviceContext*>(context)->stream()>>>( .stream()>>>(
im.data<T>(), col.data<T>(), input_channels, input_height, input_width, im.data<T>(), col.data<T>(), input_channels, input_height, input_width,
filter_height, filter_width, stride_height, stride_width, filter_height, filter_width, stride_height, stride_width,
padding_height, padding_width, output_height, output_width); padding_height, padding_width, output_height, output_width);
......
...@@ -74,7 +74,7 @@ class Im2ColFunctor { ...@@ -74,7 +74,7 @@ class Im2ColFunctor {
public: public:
void operator()(const framework::Tensor& im, framework::Tensor& col, void operator()(const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context); int padding_width, const platform::DeviceContext& context);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
...@@ -82,7 +82,7 @@ class Col2ImFunctor { ...@@ -82,7 +82,7 @@ class Col2ImFunctor {
public: public:
void operator()(framework::Tensor& im, const framework::Tensor& col, void operator()(framework::Tensor& im, const framework::Tensor& col,
int stride_height, int stride_width, int padding_height, int stride_height, int stride_width, int padding_height,
int padding_width, platform::DeviceContext* context); int padding_width, const platform::DeviceContext& context);
}; };
} // namespace math } // namespace math
......
...@@ -78,8 +78,8 @@ void testIm2col() { ...@@ -78,8 +78,8 @@ void testIm2col() {
PADDLE_THROW("no GPU support"); PADDLE_THROW("no GPU support");
#endif // PADDLE_ONLY_CPU #endif // PADDLE_ONLY_CPU
} }
im2col(input, output_cfo, stride, stride, padding, padding, context); im2col(input, output_cfo, stride, stride, padding, padding, *context);
im2col_ocf(input, output_ocf, stride, stride, padding, padding, context); im2col_ocf(input, output_ocf, stride, stride, padding, padding, *context);
float* out_cfo_ptr; float* out_cfo_ptr;
if (paddle::platform::is_cpu_place(*place)) { if (paddle::platform::is_cpu_place(*place)) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册