提交 c18f1bd7 编写于 作者: Z Zhang Ting 提交者: Aurelius84

fix the bug of conv_transpose:compatible with Anylayout setting, test=develop (#20897)

上级 3358455c
...@@ -316,7 +316,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> { ...@@ -316,7 +316,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
int user_workspace_size = ctx.Attr<int>("workspace_size_MB"); int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
const std::string data_layout_str = ctx.Attr<std::string>("data_format"); const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout = const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC); (data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first // if channel_last, transpose to channel_first
Tensor input_transpose; Tensor input_transpose;
......
...@@ -328,7 +328,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -328,7 +328,9 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice, col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
data_layout); data_layout);
} }
output_batch_vec.push_back(out_slice); if (data_layout == framework::DataLayout::kNHWC) {
output_batch_vec.push_back(out_slice);
}
} }
if (data_layout == framework::DataLayout::kNHWC) { if (data_layout == framework::DataLayout::kNHWC) {
concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2), concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
......
...@@ -60,7 +60,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -60,7 +60,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
const int w_in_end = w_in_start + filter_width * dilate_width; const int w_in_end = w_in_start + filter_width * dilate_width;
int in_offset; int in_offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
in_offset = in_offset =
((batch * input_channels + c_in) * input_height) * input_width; ((batch * input_channels + c_in) * input_height) * input_width;
} else { } else {
...@@ -78,7 +78,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -78,7 +78,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
if (h_in >= h_start && h_in < h_end && w_in >= w_start && if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
w_in < w_end) { w_in < w_end) {
int offset; int offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
offset = in_offset + h_in * input_width + w_in; offset = in_offset + h_in * input_width + w_in;
} else { } else {
offset = in_offset + offset = in_offset +
...@@ -94,7 +94,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) { ...@@ -94,7 +94,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
} }
} }
int index; int index;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
index = ((batch * gridDim.x + c_out) * output_height + h_out) * index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width + output_width +
w_out; w_out;
...@@ -131,7 +131,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( ...@@ -131,7 +131,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
const int w_in_end = w_in_start + c_filter * dilate_width; const int w_in_end = w_in_start + c_filter * dilate_width;
int in_offset; int in_offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
in_offset = in_offset =
((batch * input_channels + c_in) * input_height) * input_width; ((batch * input_channels + c_in) * input_height) * input_width;
} else { } else {
...@@ -150,7 +150,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( ...@@ -150,7 +150,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
if (h_in >= 0 && h_in < input_height && w_in >= 0 && if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
w_in < input_width) { w_in < input_width) {
int offset; int offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
offset = in_offset + h_in * input_width + w_in; offset = in_offset + h_in * input_width + w_in;
} else { } else {
offset = in_offset + offset = in_offset +
...@@ -166,7 +166,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter( ...@@ -166,7 +166,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
} }
} }
int index; int index;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
index = ((batch * gridDim.x + c_out) * output_height + h_out) * index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width + output_width +
w_out; w_out;
...@@ -252,7 +252,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad( ...@@ -252,7 +252,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
T value = 0; T value = 0;
int index; int index;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
index = index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in; w_in;
...@@ -283,7 +283,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad( ...@@ -283,7 +283,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) { s_w_out < output_width) {
int output_grad_offset; int output_grad_offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
output_grad_offset = output_grad_offset =
((batch * output_channels + c_out) * output_height + ((batch * output_channels + c_out) * output_height +
s_h_out) * s_h_out) *
...@@ -335,7 +335,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter( ...@@ -335,7 +335,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
T value = 0; T value = 0;
int index; int index;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
index = index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width + ((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in; w_in;
...@@ -363,7 +363,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter( ...@@ -363,7 +363,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 && s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) { s_w_out < output_width) {
int output_grad_offset; int output_grad_offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
output_grad_offset = output_grad_offset =
((batch * output_channels + c_out) * output_height + ((batch * output_channels + c_out) * output_height +
s_h_out) * s_h_out) *
...@@ -449,7 +449,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad( ...@@ -449,7 +449,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
#define gaid_nhwc(N, H, W, C) \ #define gaid_nhwc(N, H, W, C) \
((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C)) ((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
int input_id; int input_id;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
input_id = ((bid * (gridDim.z / filter_multiplier) + input_id = ((bid * (gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier) * kernel_id / filter_multiplier) *
input_height + input_height +
...@@ -528,19 +528,19 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T, ...@@ -528,19 +528,19 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
const DataLayout data_layout = DataLayout::kNCHW) { const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_channels = const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]); (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
const int input_height = const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]); (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
const int input_width = const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]); (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
const int output_channels = const int output_channels =
(data_layout == DataLayout::kNCHW ? output->dims()[1] (data_layout != DataLayout::kNHWC ? output->dims()[1]
: output->dims()[3]); : output->dims()[3]);
const int output_height = const int output_height =
(data_layout == DataLayout::kNCHW ? output->dims()[2] (data_layout != DataLayout::kNHWC ? output->dims()[2]
: output->dims()[1]); : output->dims()[1]);
const int output_width = const int output_width =
(data_layout == DataLayout::kNCHW ? output->dims()[3] (data_layout != DataLayout::kNHWC ? output->dims()[3]
: output->dims()[2]); : output->dims()[2]);
const int ksize_height = filter.dims()[2]; const int ksize_height = filter.dims()[2];
const int ksize_width = filter.dims()[3]; const int ksize_width = filter.dims()[3];
...@@ -614,19 +614,19 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T, ...@@ -614,19 +614,19 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
const DataLayout data_layout = DataLayout::kNCHW) { const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_channels = const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]); (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
const int input_height = const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]); (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
const int input_width = const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]); (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
const int output_channels = const int output_channels =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[1] (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
: output_grad.dims()[3]); : output_grad.dims()[3]);
const int output_height = const int output_height =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[2] (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
: output_grad.dims()[1]); : output_grad.dims()[1]);
const int output_width = const int output_width =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[3] (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
: output_grad.dims()[2]); : output_grad.dims()[2]);
const int ksize_height = filter.dims()[2]; const int ksize_height = filter.dims()[2];
const int ksize_width = filter.dims()[3]; const int ksize_width = filter.dims()[3];
...@@ -702,19 +702,19 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T, ...@@ -702,19 +702,19 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
const DataLayout data_layout = DataLayout::kNCHW) { const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_channels = const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]); (data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
const int input_height = const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]); (data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
const int input_width = const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]); (data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
const int output_channels = const int output_channels =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[1] (data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
: output_grad.dims()[3]); : output_grad.dims()[3]);
const int output_height = const int output_height =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[2] (data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
: output_grad.dims()[1]); : output_grad.dims()[1]);
const int output_width = const int output_width =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[3] (data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
: output_grad.dims()[2]); : output_grad.dims()[2]);
const int ksize_height = filter_grad->dims()[2]; const int ksize_height = filter_grad->dims()[2];
const int ksize_width = filter_grad->dims()[3]; const int ksize_width = filter_grad->dims()[3];
......
...@@ -115,7 +115,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -115,7 +115,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
if ((im_row_idx) >= 0 && (im_row_idx) < im_height && if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) { (im_col_idx) >= 0 && (im_col_idx) < im_width) {
int im_offset; int im_offset;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
im_offset = im_offset =
(c_im * im_height + im_row_idx) * im_width + im_col_idx; (c_im * im_height + im_row_idx) * im_width + im_col_idx;
} else { } else {
......
...@@ -33,14 +33,14 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height, ...@@ -33,14 +33,14 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
const int index = const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x; (blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) { if (index < num_outs) {
int w_out = (data_layout == DataLayout::kNCHW int w_out = (data_layout != DataLayout::kNHWC
? index % col_width ? index % col_width
: (index / input_channels) % col_width); : (index / input_channels) % col_width);
int h_out = (data_layout == DataLayout::kNCHW int h_out = (data_layout != DataLayout::kNHWC
? (index / col_width) % col_height ? (index / col_width) % col_height
: (index / input_channels / col_width) % col_height); : (index / input_channels / col_width) % col_height);
int channel_in = int channel_in =
(data_layout == DataLayout::kNCHW ? index / col_width / col_height (data_layout != DataLayout::kNHWC ? index / col_width / col_height
: index % input_channels); : index % input_channels);
int channel_out = channel_in * filter_height * filter_width; int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height - padding_height; int h_in = h_out * stride_height - padding_height;
...@@ -52,7 +52,7 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height, ...@@ -52,7 +52,7 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
int rIdx = h_in + i * dilation_h; int rIdx = h_in + i * dilation_h;
int cIdx = w_in + j * dilation_w; int cIdx = w_in + j * dilation_w;
int im_idx; int im_idx;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
im_idx = (channel_in * im_height + rIdx) * im_width + cIdx; im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
} else { } else {
im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in; im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
...@@ -86,11 +86,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -86,11 +86,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
"The dimension of col should be 5."); "The dimension of col should be 5.");
int im_channels = int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height = int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width = int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1]; int filter_height = col->dims()[1];
int filter_width = col->dims()[2]; int filter_width = col->dims()[2];
int col_height = col->dims()[3]; int col_height = col->dims()[3];
...@@ -127,14 +127,14 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width, ...@@ -127,14 +127,14 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
if (index < n) { if (index < n) {
T val = 0; T val = 0;
int w = (data_layout == DataLayout::kNCHW int w = (data_layout != DataLayout::kNHWC
? index % im_width + padding_width ? index % im_width + padding_width
: (index / input_channels) % im_width + padding_width); : (index / input_channels) % im_width + padding_width);
int h = (data_layout == DataLayout::kNCHW int h = (data_layout != DataLayout::kNHWC
? (index / im_width) % im_height + padding_height ? (index / im_width) % im_height + padding_height
: (index / input_channels / im_width) % im_height + : (index / input_channels / im_width) % im_height +
padding_height); padding_height);
int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height int c = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
: index % input_channels); : index % input_channels);
// compute the start and end of the output // compute the start and end of the output
...@@ -187,11 +187,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -187,11 +187,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
"The dimension of col should be 5."); "The dimension of col should be 5.");
int im_channels = int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]); (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
int im_height = int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]); (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
int im_width = int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]); (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int col_height = col.dims()[3]; int col_height = col.dims()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册