提交 2947f567 编写于 作者: C chengduoZH

follow comments

上级 dc7d0735
......@@ -42,14 +42,20 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
......@@ -62,16 +68,14 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_up) < 0 ||
(im_row_idx - padding_up) >= input_height ||
(im_col_idx - padding_left) < 0 ||
(im_col_idx - padding_left) >= input_width) {
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
if (im_row_idx < 0 || im_row_idx >= input_height || im_col_idx < 0 ||
im_col_idx >= input_width) {
col_data[(c * output_height + h) * output_width + w] = T(0);
} else {
im_row_idx += c_im * input_height - padding_up;
im_col_idx -= padding_left;
im_row_idx += c_im * input_height;
col_data[(c * output_height + h) * output_width + w] =
im_data[im_row_idx * input_width + im_col_idx];
}
......@@ -104,14 +108,20 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int output_height = col.dims()[3];
int output_width = col.dims()[4];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
int channels_col = input_channels * filter_height * filter_width;
......@@ -124,14 +134,12 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
int c_im = c / filter_width / filter_height;
for (int h = 0; h < output_height; ++h) {
for (int w = 0; w < output_width; ++w) {
int im_row_idx = h * stride_height + h_offset;
int im_col_idx = w * stride_width + w_offset;
if ((im_row_idx - padding_up) >= 0 &&
(im_row_idx - padding_up) < input_height &&
(im_col_idx - padding_left) >= 0 &&
(im_col_idx - padding_left) < input_width) {
im_row_idx += c_im * input_height - padding_up;
im_col_idx -= padding_left;
int im_row_idx = h * stride_height + h_offset - padding_up;
int im_col_idx = w * stride_width + w_offset - padding_left;
if ((im_row_idx) >= 0 && (im_row_idx) < input_height &&
(im_col_idx) >= 0 && (im_col_idx) < input_width) {
im_row_idx += c_im * input_height;
im_data[im_row_idx * input_width + im_col_idx] +=
col_data[(c * output_height + h) * output_width + w];
}
......@@ -173,14 +181,20 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>();
T* col_data = col.data<T>();
......@@ -243,14 +257,20 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kOCF,
int output_height = col.dims()[0];
int output_width = col.dims()[1];
PADDLE_ENFORCE((input_height + padding_up + padding_down - filter_height) /
stride_height +
1 ==
output_height);
PADDLE_ENFORCE((input_width + padding_left + padding_right - filter_width) /
stride_width +
1 ==
output_width);
PADDLE_ENFORCE_EQ(
(input_height + padding_up + padding_down - filter_height) /
stride_height +
1,
output_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(input_width + padding_left + padding_right - filter_width) /
stride_width +
1,
output_width,
"output_width and padding(padding_left, padding_right) are "
"inconsistent.");
T* im_data = im.data<T>();
const T* col_data = col.data<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册