From 78910480c198f0aba3d2503e72ca410b1e5fe537 Mon Sep 17 00:00:00 2001 From: Zhang Ting Date: Sun, 13 Oct 2019 17:41:18 +0800 Subject: [PATCH] fix conv_transpose's bug: compatible with Anylayout setting, test=develop (#20589) --- paddle/fluid/operators/conv_transpose_op.cc | 8 ++--- paddle/fluid/operators/conv_transpose_op.h | 34 ++++++++++---------- paddle/fluid/operators/math/im2col.cc | 6 ++-- paddle/fluid/operators/math/im2col_cfo_cpu.h | 30 ++++++++--------- paddle/fluid/operators/math/vol2col.cc | 20 ++++++------ paddle/fluid/operators/math/vol2col.cu | 26 +++++++-------- 6 files changed, 62 insertions(+), 62 deletions(-) diff --git a/paddle/fluid/operators/conv_transpose_op.cc b/paddle/fluid/operators/conv_transpose_op.cc index 3758a3c079..f0962a74b0 100644 --- a/paddle/fluid/operators/conv_transpose_op.cc +++ b/paddle/fluid/operators/conv_transpose_op.cc @@ -64,7 +64,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { "dimension should be the same."); const int64_t C = - (data_layout == DataLayout::kNCHW ? in_dims[1] + (data_layout != DataLayout::kNHWC ? in_dims[1] : in_dims[in_dims.size() - 1]); PADDLE_ENFORCE_EQ( C, filter_dims[0], @@ -72,7 +72,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { "be equal to the number of filter's channels."); framework::DDim in_data_dims; - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); @@ -84,10 +84,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { in_data_dims, strides, ksize); std::vector output_shape({in_dims[0]}); - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { output_shape.push_back(filter_dims[1] * groups); } - const int offset = (data_layout == DataLayout::kNCHW ? 2 : 1); + const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1); for (size_t i = 0; i < strides.size(); ++i) { auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; auto infer_shape = (in_dims[i + offset] - 1) * strides[i] - diff --git a/paddle/fluid/operators/conv_transpose_op.h b/paddle/fluid/operators/conv_transpose_op.h index 56cfa8618f..8bcef8e49a 100644 --- a/paddle/fluid/operators/conv_transpose_op.h +++ b/paddle/fluid/operators/conv_transpose_op.h @@ -176,7 +176,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); framework::DDim in_data_dims; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); @@ -198,7 +198,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w} size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { col_shape_vec[0] = out_dims[1] / groups; for (size_t j = 0; j < data_dim; ++j) { col_shape_vec[j + 1] = filter_shape_vec[j + 2]; @@ -234,7 +234,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last DDim input_matrix_shape; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { input_matrix_shape = {in_dims[1], col_matrix_shape[1]}; } else { input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]}; @@ -242,7 +242,7 @@ class GemmConvTransposeKernel : public framework::OpKernel { // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w) DDim filter_matrix_shape; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { filter_matrix_shape = {in_dims[1], col_matrix_shape[0]}; } else { filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]}; @@ -256,12 +256,12 @@ class GemmConvTransposeKernel : public framework::OpKernel { set_zero(dev_ctx, output, static_cast(0)); int in_step = - (data_layout == framework::DataLayout::kNCHW + (data_layout != framework::DataLayout::kNHWC ? static_cast(in_dims[1]) / groups : static_cast(in_dims[in_dims.size() - 1]) / groups); int out_step = - (data_layout == framework::DataLayout::kNCHW + (data_layout != framework::DataLayout::kNHWC ? static_cast(out_dims[1]) / groups : static_cast(out_dims[out_dims.size() - 1]) / groups); math::Col2ImFunctor col2im; @@ -284,14 +284,14 @@ class GemmConvTransposeKernel : public framework::OpKernel { for (int g = 0; g < groups; g++) { int64_t start = g * in_step; int64_t end = (g + 1) * in_step; - int axes = (data_layout == framework::DataLayout::kNCHW ? 0 : 1); + int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1); Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); Tensor in_slice, out_slice; // col_matrix = filter_slice * input_slice // of shape (o_c/g * k_h * k_w, h * w) // or (o_c/g * k_d * k_h * k_w, d * h * w) - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step); out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step); blas.MatMul(filter_slice, true, in_slice, false, static_cast(1.0), @@ -372,7 +372,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { const int batch_size = static_cast(input->dims()[0]); framework::DDim in_data_dims; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); @@ -394,7 +394,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for size_t data_dim = filter_shape_vec.size() - 2; std::vector col_shape_vec(1 + 2 * data_dim); - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { col_shape_vec[0] = out_grad_dims[1]; for (size_t j = 0; j < data_dim; ++j) { col_shape_vec[j + 1] = filter_shape_vec[j + 2]; @@ -421,7 +421,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last DDim input_matrix_shape; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { input_matrix_shape = {in_dims[1], col_matrix_shape[1]}; } else { input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]}; @@ -429,7 +429,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w) DDim filter_matrix_shape; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups}; } else { filter_matrix_shape = {in_dims[in_dims.size() - 1], @@ -438,7 +438,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { filter.Resize(filter_matrix_shape); int in_step = - (data_layout == framework::DataLayout::kNCHW + (data_layout != framework::DataLayout::kNHWC ? static_cast(in_dims[1]) / groups : static_cast(in_dims[in_dims.size() - 1]) / groups); int col_step = static_cast(col_matrix_shape[0]) / groups; @@ -531,7 +531,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { // k_h * k_w, d * h * w) Tensor col_matrix_slice = col_matrix.Slice(g * col_step, (g + 1) * col_step); - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { Tensor input_grad_slice = input_grad_batch.Slice(g * in_step, (g + 1) * in_step); blas.MatMul(filter_slice, false, col_matrix_slice, false, @@ -579,7 +579,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel { filter_grad_.Slice(g * in_step, (g + 1) * in_step); Tensor col_matrix_slice = col_matrix.Slice(g * col_step, (g + 1) * col_step); - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { Tensor in_batch_slice = in_batch.Slice(g * in_step, (g + 1) * in_step); blas.MatMul(in_batch_slice, false, col_matrix_slice, true, @@ -629,7 +629,7 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel { auto filter_dims = filter.dims(); framework::DDim in_data_dims; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); @@ -684,7 +684,7 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel { auto filter_dims = filter.dims(); framework::DDim in_data_dims; - if (data_layout == framework::DataLayout::kNCHW) { + if (data_layout != framework::DataLayout::kNHWC) { in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); } else { in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); diff --git a/paddle/fluid/operators/math/im2col.cc b/paddle/fluid/operators/math/im2col.cc index 4736c78fe9..835998fdce 100644 --- a/paddle/fluid/operators/math/im2col.cc +++ b/paddle/fluid/operators/math/im2col.cc @@ -74,11 +74,11 @@ class Col2ImFunctordims()[0] : im->dims()[2]); + (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]); int im_height = - (data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]); + (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]); int im_width = - (data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]); + (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]); int filter_height = col.dims()[1]; int filter_width = col.dims()[2]; int col_height = col.dims()[3]; diff --git a/paddle/fluid/operators/math/im2col_cfo_cpu.h b/paddle/fluid/operators/math/im2col_cfo_cpu.h index bd42bd1a18..01f1e220e6 100644 --- a/paddle/fluid/operators/math/im2col_cfo_cpu.h +++ b/paddle/fluid/operators/math/im2col_cfo_cpu.h @@ -33,11 +33,11 @@ inline void im2col_common(const framework::Tensor& im, framework::Tensor* col, const DataLayout data_layout = DataLayout::kNCHW) { int im_channels = - (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]); int im_height = - (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]); int im_width = - (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); + (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int output_height = col->dims()[3]; @@ -55,7 +55,7 @@ inline void im2col_common(const framework::Tensor& im, for (int w = 0; w < output_width; ++w) { int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int im_idx; - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; } else { im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im; @@ -79,11 +79,11 @@ inline void im2col_sh1sw1dh1dw1ph0pw0( const framework::Tensor& im, framework::Tensor* col, const DataLayout data_layout = DataLayout::kNCHW) { int im_channels = - (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]); int im_height = - (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]); int im_width = - (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); + (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int output_height = col->dims()[3]; @@ -103,7 +103,7 @@ inline void im2col_sh1sw1dh1dw1ph0pw0( const T* src_data = src_data_ic; for (int kh = 0; kh < filter_height; ++kh) { for (int kw = 0; kw < filter_width; ++kw) { - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { std::memcpy(dst_data, src_data + kw, copy_size); } else { for (int kow = 0; kow < output_width; ++kow) { @@ -131,11 +131,11 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, framework::Tensor* col, const DataLayout data_layout) { int im_channels = - (data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); + (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]); int im_height = - (data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); + (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]); int im_width = - (data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); + (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]); int filter_height = col->dims()[1]; int filter_width = col->dims()[2]; int output_height = col->dims()[3]; @@ -205,7 +205,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, dst_data = dst_data + col_matrix_width; continue; } - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { std::memcpy(dst_data + plw, src_data, copy_size); } else { for (int kow = 0; kow < output_width - plw - prw; ++kow) { @@ -261,7 +261,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, // TODO(TJ): reuse plw-kw outside this for // try to unify for (int kw = 0; kw < plw; ++kw) { - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { std::memcpy(dst_data + (plw - kw), src_data, sizeof(T) * (output_width - (plw - kw))); } else { @@ -276,7 +276,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, dst_data = dst_data + col_matrix_width; } for (int kw = plw; kw < filter_width - prw; ++kw) { - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { std::memcpy(dst_data, src_data + (kw - plw), sizeof(T) * output_width); } else { @@ -292,7 +292,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, } int i = 1; for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) { - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { std::memcpy(dst_data, src_data + (kw - plw), sizeof(T) * (output_width - i)); } else { diff --git a/paddle/fluid/operators/math/vol2col.cc b/paddle/fluid/operators/math/vol2col.cc index da051034da..01f50727b4 100644 --- a/paddle/fluid/operators/math/vol2col.cc +++ b/paddle/fluid/operators/math/vol2col.cc @@ -40,13 +40,13 @@ class Vol2ColFunctor { "The dimension of col should be 7."); int input_channels = - (data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]); + (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); int input_depth = - (data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]); + (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]); int input_height = - (data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]); + (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]); int input_width = - (data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]); + (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]); int filter_depth = col->dims()[1]; int filter_height = col->dims()[2]; int filter_width = col->dims()[3]; @@ -104,7 +104,7 @@ class Vol2ColFunctor { int col_idx = ((c * output_depth + d) * output_height + h) * output_width + w; int vol_idx; - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) * input_width + w_pad; @@ -146,13 +146,13 @@ class Col2VolFunctor { "The dimension of col should be 7."); int input_channels = - (data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]); + (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); int input_depth = - (data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]); + (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]); int input_height = - (data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]); + (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]); int input_width = - (data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]); + (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]); int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; @@ -209,7 +209,7 @@ class Col2VolFunctor { if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { int vol_idx; - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) * input_width + w_pad; diff --git a/paddle/fluid/operators/math/vol2col.cu b/paddle/fluid/operators/math/vol2col.cu index b42dd55bda..9de9051f51 100644 --- a/paddle/fluid/operators/math/vol2col.cu +++ b/paddle/fluid/operators/math/vol2col.cu @@ -55,7 +55,7 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, int h = h_in + i * dilation_h; int w = w_in + j * dilation_w; int vol_idx; - if (data_layout == DataLayout::kNCHW) { + if (data_layout != DataLayout::kNHWC) { vol_idx = ((channel_in * depth + d) * height + h) * width + w; } else { vol_idx = @@ -96,13 +96,13 @@ class Vol2ColFunctor { "The dimension of col should be 7."); int input_channels = - (data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]); + (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]); int input_depth = - (data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]); + (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]); int input_height = - (data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]); + (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]); int input_width = - (data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]); + (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]); int filter_depth = col->dims()[1]; int filter_height = col->dims()[2]; int filter_width = col->dims()[3]; @@ -170,16 +170,16 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; index += blockDim.x * gridDim.x) { T src_val = 0; - int w = (data_layout == DataLayout::kNCHW + int w = (data_layout != DataLayout::kNHWC ? index % width + padding_width : (index / input_channels) % width + padding_width); - int h = (data_layout == DataLayout::kNCHW + int h = (data_layout != DataLayout::kNHWC ? (index / width) % height + padding_height : (index / input_channels / width) % height + padding_height); - int d = (data_layout == DataLayout::kNCHW + int d = (data_layout != DataLayout::kNHWC ? (index / width / height) % depth + padding_depth : index / input_channels / width / height + padding_depth); - int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth + int c = (data_layout != DataLayout::kNHWC ? index / width / height / depth : index % input_channels); // compute the start and end of the output @@ -247,13 +247,13 @@ class Col2VolFunctor { "The dimension of col should be 7."); int input_channels = - (data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]); + (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]); int input_depth = - (data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]); + (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]); int input_height = - (data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]); + (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]); int input_width = - (data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]); + (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]); int filter_depth = col.dims()[1]; int filter_height = col.dims()[2]; int filter_width = col.dims()[3]; -- GitLab