提交 416b3417 编写于 作者: Z Zhang Ting 提交者: Aurelius84

[cherry-pick] fix the bug of conv_transpose: compitable with AnyLayout...

[cherry-pick] fix the bug of conv_transpose: compitable with AnyLayout setting, test=release/1.6 #(20897) (#20918)
上级 1948210c
......@@ -316,7 +316,7 @@ class CUDNNConvTransposeGradOpKernel : public framework::OpKernel<T> {
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
const std::string data_layout_str = ctx.Attr<std::string>("data_format");
const paddle::operators::DataLayout data_layout =
(data_layout_str == "NCHW" ? DataLayout::kNCHW : DataLayout::kNHWC);
(data_layout_str != "NHWC" ? DataLayout::kNCHW : DataLayout::kNHWC);
// if channel_last, transpose to channel_first
Tensor input_transpose;
......
......@@ -328,8 +328,10 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
col2vol(dev_ctx, col, dilations, strides, paddings, &out_slice,
data_layout);
}
if (data_layout == framework::DataLayout::kNHWC) {
output_batch_vec.push_back(out_slice);
}
}
if (data_layout == framework::DataLayout::kNHWC) {
concat_functor(dev_ctx, output_batch_vec, static_cast<int>(D - 2),
&output_batch);
......
......@@ -60,7 +60,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
const int w_in_end = w_in_start + filter_width * dilate_width;
int in_offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
} else {
......@@ -78,7 +78,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
w_in < w_end) {
int offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
offset = in_offset + h_in * input_width + w_in;
} else {
offset = in_offset +
......@@ -94,7 +94,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
}
}
int index;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width +
w_out;
......@@ -131,7 +131,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
const int w_in_end = w_in_start + c_filter * dilate_width;
int in_offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
in_offset =
((batch * input_channels + c_in) * input_height) * input_width;
} else {
......@@ -150,7 +150,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
w_in < input_width) {
int offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
offset = in_offset + h_in * input_width + w_in;
} else {
offset = in_offset +
......@@ -166,7 +166,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
}
}
int index;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
index = ((batch * gridDim.x + c_out) * output_height + h_out) *
output_width +
w_out;
......@@ -252,7 +252,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
T value = 0;
int index;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
......@@ -283,7 +283,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) {
int output_grad_offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
......@@ -335,7 +335,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
T value = 0;
int index;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
index =
((batch * gridDim.x + c_in) * input_height + h_in) * input_width +
w_in;
......@@ -363,7 +363,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
s_w_out < output_width) {
int output_grad_offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
output_grad_offset =
((batch * output_channels + c_out) * output_height +
s_h_out) *
......@@ -449,7 +449,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
#define gaid_nhwc(N, H, W, C) \
((((N)*output_height + (H)) * output_width + (W)) * gridDim.z + (C))
int input_id;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
input_id = ((bid * (gridDim.z / filter_multiplier) +
kernel_id / filter_multiplier) *
input_height +
......@@ -528,19 +528,19 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0];
const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
(data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
(data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
(data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
const int output_channels =
(data_layout == DataLayout::kNCHW ? output->dims()[1]
(data_layout != DataLayout::kNHWC ? output->dims()[1]
: output->dims()[3]);
const int output_height =
(data_layout == DataLayout::kNCHW ? output->dims()[2]
(data_layout != DataLayout::kNHWC ? output->dims()[2]
: output->dims()[1]);
const int output_width =
(data_layout == DataLayout::kNCHW ? output->dims()[3]
(data_layout != DataLayout::kNHWC ? output->dims()[3]
: output->dims()[2]);
const int ksize_height = filter.dims()[2];
const int ksize_width = filter.dims()[3];
......@@ -614,19 +614,19 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0];
const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
(data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
(data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
(data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
const int output_channels =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
(data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
: output_grad.dims()[3]);
const int output_height =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
(data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
: output_grad.dims()[1]);
const int output_width =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
(data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
: output_grad.dims()[2]);
const int ksize_height = filter.dims()[2];
const int ksize_width = filter.dims()[3];
......@@ -702,19 +702,19 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
const DataLayout data_layout = DataLayout::kNCHW) {
const int batch_size = input.dims()[0];
const int input_channels =
(data_layout == DataLayout::kNCHW ? input.dims()[1] : input.dims()[3]);
(data_layout != DataLayout::kNHWC ? input.dims()[1] : input.dims()[3]);
const int input_height =
(data_layout == DataLayout::kNCHW ? input.dims()[2] : input.dims()[1]);
(data_layout != DataLayout::kNHWC ? input.dims()[2] : input.dims()[1]);
const int input_width =
(data_layout == DataLayout::kNCHW ? input.dims()[3] : input.dims()[2]);
(data_layout != DataLayout::kNHWC ? input.dims()[3] : input.dims()[2]);
const int output_channels =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[1]
(data_layout != DataLayout::kNHWC ? output_grad.dims()[1]
: output_grad.dims()[3]);
const int output_height =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[2]
(data_layout != DataLayout::kNHWC ? output_grad.dims()[2]
: output_grad.dims()[1]);
const int output_width =
(data_layout == DataLayout::kNCHW ? output_grad.dims()[3]
(data_layout != DataLayout::kNHWC ? output_grad.dims()[3]
: output_grad.dims()[2]);
const int ksize_height = filter_grad->dims()[2];
const int ksize_width = filter_grad->dims()[3];
......
......@@ -115,7 +115,7 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
if ((im_row_idx) >= 0 && (im_row_idx) < im_height &&
(im_col_idx) >= 0 && (im_col_idx) < im_width) {
int im_offset;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
im_offset =
(c_im * im_height + im_row_idx) * im_width + im_col_idx;
} else {
......
......@@ -33,14 +33,14 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
const int index =
(blockIdx.x * gridDim.y + blockIdx.y) * blockDim.x + threadIdx.x;
if (index < num_outs) {
int w_out = (data_layout == DataLayout::kNCHW
int w_out = (data_layout != DataLayout::kNHWC
? index % col_width
: (index / input_channels) % col_width);
int h_out = (data_layout == DataLayout::kNCHW
int h_out = (data_layout != DataLayout::kNHWC
? (index / col_width) % col_height
: (index / input_channels / col_width) % col_height);
int channel_in =
(data_layout == DataLayout::kNCHW ? index / col_width / col_height
(data_layout != DataLayout::kNHWC ? index / col_width / col_height
: index % input_channels);
int channel_out = channel_in * filter_height * filter_width;
int h_in = h_out * stride_height - padding_height;
......@@ -52,7 +52,7 @@ __global__ void im2col(const T* data_im, int num_outs, int im_height,
int rIdx = h_in + i * dilation_h;
int cIdx = w_in + j * dilation_w;
int im_idx;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
im_idx = (channel_in * im_height + rIdx) * im_width + cIdx;
} else {
im_idx = (rIdx * im_width + cIdx) * input_channels + channel_in;
......@@ -86,11 +86,11 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
(data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
(data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
(data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int col_height = col->dims()[3];
......@@ -127,14 +127,14 @@ __global__ void col2im(int n, const T* data_col, int im_height, int im_width,
if (index < n) {
T val = 0;
int w = (data_layout == DataLayout::kNCHW
int w = (data_layout != DataLayout::kNHWC
? index % im_width + padding_width
: (index / input_channels) % im_width + padding_width);
int h = (data_layout == DataLayout::kNCHW
int h = (data_layout != DataLayout::kNHWC
? (index / im_width) % im_height + padding_height
: (index / input_channels / im_width) % im_height +
padding_height);
int c = (data_layout == DataLayout::kNCHW ? index / im_width / im_height
int c = (data_layout != DataLayout::kNHWC ? index / im_width / im_height
: index % input_channels);
// compute the start and end of the output
......@@ -187,11 +187,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
(data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
(data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
(data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册