未验证 提交 ecc3098e 编写于 作者: Y yeliang2258 提交者: GitHub

fix vol2col (#44998)

上级 94c17a0f
......@@ -50,18 +50,18 @@ class Vol2ColFunctor<phi::CPUContext, T> {
int input_channels =
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
int64_t input_depth =
(data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height =
int64_t input_height =
(data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width =
int64_t input_width =
(data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
int output_depth = col->dims()[4];
int output_height = col->dims()[5];
int output_width = col->dims()[6];
int64_t output_depth = col->dims()[4];
int64_t output_height = col->dims()[5];
int64_t output_width = col->dims()[6];
int channels_col =
input_channels * filter_depth * filter_height * filter_width;
......@@ -109,22 +109,22 @@ class Vol2ColFunctor<phi::CPUContext, T> {
output_width));
const T* vol_data = vol.data<T>();
T* col_data = col->data<T>();
for (int c = 0; c < channels_col; ++c) {
for (auto c = 0; c < channels_col; ++c) {
int w_offset = c % filter_width;
int h_offset = (c / filter_width) % filter_height;
int d_offset = (c / filter_width / filter_height) % filter_depth;
int c_in = c / filter_width / filter_height / filter_depth;
for (int d = 0; d < output_depth; ++d) {
int d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
for (int h = 0; h < output_height; ++h) {
int h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
for (int w = 0; w < output_width; ++w) {
int w_pad = w * strides[2] - pad_w_left + w_offset * dilations[2];
int64_t c_in = c / filter_width / filter_height / filter_depth;
for (auto d = 0; d < output_depth; ++d) {
int64_t d_pad = d * strides[0] - pad_d_forth + d_offset * dilations[0];
for (auto h = 0; h < output_height; ++h) {
int64_t h_pad = h * strides[1] - pad_h_up + h_offset * dilations[1];
for (auto w = 0; w < output_width; ++w) {
int64_t w_pad =
w * strides[2] - pad_w_left + w_offset * dilations[2];
int col_idx =
int64_t col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
int vol_idx;
int64_t vol_idx;
if (data_layout != DataLayout::kNHWC) {
vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册