提交 200a02ec 编写于 作者: C chengduo 提交者: GitHub

Merge pull request #5041 from chengduoZH/fix_im2col_interface

fix im2col interface
...@@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> { ...@@ -120,7 +120,7 @@ class GemmConv2DTransposeKernel : public framework::OpKernel<T> {
math::matmul<Place, T>(context.device_context(), filter, true, math::matmul<Place, T>(context.device_context(), filter, true,
input_batch, false, T(1.0), &col_matrix, T(0.0)); input_batch, false, T(1.0), &col_matrix, T(0.0));
col2im(context.device_context(), output_batch, col, strides[0], col2im(context.device_context(), output_batch, col, strides[0],
strides[1], 0, 0); strides[1], 0, 0, 0, 0);
} }
} }
}; };
...@@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> { ...@@ -206,7 +206,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
// im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w) // im2col: dy from (c, o_h, o_w) -> (c * k_h * k_w, h * w)
im2col(context.device_context(), output_grad_batch, col, strides[0], im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[1]); strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm: dx = filter * dy // gemm: dx = filter * dy
// (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h) // (m, c * k_h * k_w) * (c * k_h * k_w, h * w) -> (m, c, h)
...@@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> { ...@@ -238,7 +238,7 @@ class GemmConv2DTransposeGradKernel : public framework::OpKernel<T> {
// im2col: (c * h * w, k_h * k_w) // im2col: (c * h * w, k_h * k_w)
im2col(context.device_context(), output_grad_batch, col, strides[0], im2col(context.device_context(), output_grad_batch, col, strides[0],
strides[1], paddings[0], paddings[1]); strides[1], paddings[0], paddings[0], paddings[1], paddings[1]);
// gemm: d_filter = x * y_grad^T // gemm: d_filter = x * y_grad^T
// (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h) // (m, c * h * w) * (k_h * k_w, c * h * w) -> (m, c, h)
......
...@@ -75,8 +75,7 @@ class Im2ColFunctor { ...@@ -75,8 +75,7 @@ class Im2ColFunctor {
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& im, framework::Tensor& col, const framework::Tensor& im, framework::Tensor& col,
int stride_height, int stride_width, int padding_up, int stride_height, int stride_width, int padding_up,
int padding_down, int padding_left = 0, int padding_down, int padding_left, int padding_right);
int padding_right = 0);
}; };
template <ColFormat Format, typename Place, typename T> template <ColFormat Format, typename Place, typename T>
...@@ -85,7 +84,7 @@ class Col2ImFunctor { ...@@ -85,7 +84,7 @@ class Col2ImFunctor {
void operator()(const platform::DeviceContext& context, framework::Tensor& im, void operator()(const platform::DeviceContext& context, framework::Tensor& im,
const framework::Tensor& col, int stride_height, const framework::Tensor& col, int stride_height,
int stride_width, int padding_up, int padding_down, int stride_width, int padding_up, int padding_down,
int padding_left = 0, int padding_right = 0); int padding_left, int padding_right);
}; };
} // namespace math } // namespace math
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册