提交 bb546cf1 编写于 作者: H hedaoyuan

Bug fix.

上级 659f2f71
...@@ -75,8 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -75,8 +75,6 @@ class GemmConv2DKernel : public framework::OpKernel {
framework::DDim output_matrix_shape = {output_channels, framework::DDim output_matrix_shape = {output_channels,
output_height * output_width}; output_height * output_width};
auto device_context = context.device_context();
// convolution operator: im2col + gemm // convolution operator: im2col + gemm
int in_step = input_channels / groups; int in_step = input_channels / groups;
int out_step = output_channels / groups; int out_step = output_channels / groups;
...@@ -87,13 +85,13 @@ class GemmConv2DKernel : public framework::OpKernel { ...@@ -87,13 +85,13 @@ class GemmConv2DKernel : public framework::OpKernel {
// im2col // im2col
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1], im2col(in_slice, col, strides[0], strides[1], paddings[0], paddings[1],
device_context); context.device_context());
// gemm // gemm
Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step); Tensor out_slice = out_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step); Tensor filter_slice = filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(device_context, filter_slice, false, col_matrix, math::matmul<Place, T>(context.device_context(), filter_slice, false,
false, T(1.0), &out_slice, T(0.0)); col_matrix, false, T(1.0), &out_slice, T(0.0));
} }
} }
} }
...@@ -159,8 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -159,8 +157,6 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
filter.numel() / filter.dims()[0]}; filter.numel() / filter.dims()[0]};
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
auto device_context = context.device_context();
// convolution backward input operator: gemm + col2im // convolution backward input operator: gemm + col2im
// convolution backward weight operator: im2col + gemm // convolution backward weight operator: im2col + gemm
int in_step = input_channels / groups; int in_step = input_channels / groups;
...@@ -182,7 +178,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -182,7 +178,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step); out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor filter_slice = Tensor filter_slice =
filter.Slice<T>(g * out_step, (g + 1) * out_step); filter.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(device_context, filter_slice, true, math::matmul<Place, T>(context.device_context(), filter_slice, true,
out_grad_slice, false, T(1.0), &col_matrix, out_grad_slice, false, T(1.0), &col_matrix,
T(0.0)); T(0.0));
...@@ -190,7 +186,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -190,7 +186,7 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
Tensor in_grad_slice = Tensor in_grad_slice =
in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step); in_grad_batch.Slice<T>(g * in_step, (g + 1) * in_step);
col2im(in_grad_slice, col, strides[0], strides[1], paddings[0], col2im(in_grad_slice, col, strides[0], strides[1], paddings[0],
paddings[1], device_context); paddings[1], context.device_context());
} }
} }
} }
...@@ -212,14 +208,14 @@ class GemmConvGrad2DKernel : public framework::OpKernel { ...@@ -212,14 +208,14 @@ class GemmConvGrad2DKernel : public framework::OpKernel {
out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step); out_grad_batch.Slice<T>(g * out_step, (g + 1) * out_step);
Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step); Tensor in_slice = in_batch.Slice<T>(g * in_step, (g + 1) * in_step);
im2col(in_slice, col, strides[0], strides[1], paddings[0], im2col(in_slice, col, strides[0], strides[1], paddings[0],
paddings[1], device_context); paddings[1], context.device_context());
// gemm // gemm
Tensor filter_grad_slice = Tensor filter_grad_slice =
filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step); filter_grad_.Slice<T>(g * out_step, (g + 1) * out_step);
math::matmul<Place, T>(device_context, out_grad_slice, false, math::matmul<Place, T>(context.device_context(), out_grad_slice,
col_matrix, true, T(1.0), &filter_grad_slice, false, col_matrix, true, T(1.0),
T(1.0)); &filter_grad_slice, T(1.0));
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册