提交 63e9b700 编写于 作者: Z Zhang Ting 提交者: lanxianghit

[cherry-pick] fix conv_transpose's bug: compatible with Anylayout setting

fix conv_transpose's bug: compatible with Anylayout setting: cherry-pick 20589
上级 837378b5
......@@ -64,7 +64,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"dimension should be the same.");
const int64_t C =
(data_layout == DataLayout::kNCHW ? in_dims[1]
(data_layout != DataLayout::kNHWC ? in_dims[1]
: in_dims[in_dims.size() - 1]);
PADDLE_ENFORCE_EQ(
C, filter_dims[0],
......@@ -72,7 +72,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"be equal to the number of filter's channels.");
framework::DDim in_data_dims;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
......@@ -84,10 +84,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]});
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups);
}
const int offset = (data_layout == DataLayout::kNCHW ? 2 : 1);
const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1);
for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
auto infer_shape = (in_dims[i + offset] - 1) * strides[i] -
......
......@@ -176,7 +176,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
......@@ -198,7 +198,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
col_shape_vec[0] = out_dims[1] / groups;
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
......@@ -234,7 +234,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
// input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
DDim input_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
} else {
input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
......@@ -242,7 +242,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
DDim filter_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
} else {
filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
......@@ -256,12 +256,12 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, output, static_cast<T>(0));
int in_step =
(data_layout == framework::DataLayout::kNCHW
(data_layout != framework::DataLayout::kNHWC
? static_cast<int>(in_dims[1]) / groups
: static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
int out_step =
(data_layout == framework::DataLayout::kNCHW
(data_layout != framework::DataLayout::kNHWC
? static_cast<int>(out_dims[1]) / groups
: static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
......@@ -284,14 +284,14 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
for (int g = 0; g < groups; g++) {
int64_t start = g * in_step;
int64_t end = (g + 1) * in_step;
int axes = (data_layout == framework::DataLayout::kNCHW ? 0 : 1);
int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
Tensor in_slice, out_slice;
// col_matrix = filter_slice * input_slice
// of shape (o_c/g * k_h * k_w, h * w)
// or (o_c/g * k_d * k_h * k_w, d * h * w)
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
......@@ -372,7 +372,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]);
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
......@@ -394,7 +394,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
col_shape_vec[0] = out_grad_dims[1];
for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2];
......@@ -421,7 +421,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
// input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
DDim input_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
} else {
input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
......@@ -429,7 +429,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
DDim filter_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
} else {
filter_matrix_shape = {in_dims[in_dims.size() - 1],
......@@ -438,7 +438,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
filter.Resize(filter_matrix_shape);
int in_step =
(data_layout == framework::DataLayout::kNCHW
(data_layout != framework::DataLayout::kNHWC
? static_cast<int>(in_dims[1]) / groups
: static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
......@@ -531,7 +531,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// k_h * k_w, d * h * w)
Tensor col_matrix_slice =
col_matrix.Slice(g * col_step, (g + 1) * col_step);
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
Tensor input_grad_slice =
input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
blas.MatMul(filter_slice, false, col_matrix_slice, false,
......@@ -579,7 +579,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
filter_grad_.Slice(g * in_step, (g + 1) * in_step);
Tensor col_matrix_slice =
col_matrix.Slice(g * col_step, (g + 1) * col_step);
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
Tensor in_batch_slice =
in_batch.Slice(g * in_step, (g + 1) * in_step);
blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
......@@ -629,7 +629,7 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
......@@ -684,7 +684,7 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
auto filter_dims = filter.dims();
framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) {
if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
......
......@@ -74,11 +74,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5.");
int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]);
(data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]);
(data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]);
(data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1];
int filter_width = col.dims()[2];
int col_height = col.dims()[3];
......
......@@ -33,11 +33,11 @@ inline void im2col_common(const framework::Tensor& im,
framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
(data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
(data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
(data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
......@@ -55,7 +55,7 @@ inline void im2col_common(const framework::Tensor& im,
for (int w = 0; w < output_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int im_idx;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
} else {
im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
......@@ -79,11 +79,11 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(
const framework::Tensor& im, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
(data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
(data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
(data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
......@@ -103,7 +103,7 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(
const T* src_data = src_data_ic;
for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) {
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data, src_data + kw, copy_size);
} else {
for (int kow = 0; kow < output_width; ++kow) {
......@@ -131,11 +131,11 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
framework::Tensor* col,
const DataLayout data_layout) {
int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]);
(data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]);
(data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]);
(data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1];
int filter_width = col->dims()[2];
int output_height = col->dims()[3];
......@@ -205,7 +205,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
dst_data = dst_data + col_matrix_width;
continue;
}
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data + plw, src_data, copy_size);
} else {
for (int kow = 0; kow < output_width - plw - prw; ++kow) {
......@@ -261,7 +261,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
// TODO(TJ): reuse plw-kw outside this for
// try to unify
for (int kw = 0; kw < plw; ++kw) {
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data + (plw - kw), src_data,
sizeof(T) * (output_width - (plw - kw)));
} else {
......@@ -276,7 +276,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
dst_data = dst_data + col_matrix_width;
}
for (int kw = plw; kw < filter_width - prw; ++kw) {
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * output_width);
} else {
......@@ -292,7 +292,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
}
int i = 1;
for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * (output_width - i));
} else {
......
......@@ -40,13 +40,13 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
"The dimension of col should be 7.");
int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
(data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
(data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
(data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
......@@ -104,7 +104,7 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
......@@ -146,13 +146,13 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
"The dimension of col should be 7.");
int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
(data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
(data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
(data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
......@@ -209,7 +209,7 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width +
w_pad;
......
......@@ -55,7 +55,7 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
int h = h_in + i * dilation_h;
int w = w_in + j * dilation_w;
int vol_idx;
if (data_layout == DataLayout::kNCHW) {
if (data_layout != DataLayout::kNHWC) {
vol_idx = ((channel_in * depth + d) * height + h) * width + w;
} else {
vol_idx =
......@@ -96,13 +96,13 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
"The dimension of col should be 7.");
int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]);
(data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]);
(data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]);
(data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]);
(data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1];
int filter_height = col->dims()[2];
int filter_width = col->dims()[3];
......@@ -170,16 +170,16 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) {
T src_val = 0;
int w = (data_layout == DataLayout::kNCHW
int w = (data_layout != DataLayout::kNHWC
? index % width + padding_width
: (index / input_channels) % width + padding_width);
int h = (data_layout == DataLayout::kNCHW
int h = (data_layout != DataLayout::kNHWC
? (index / width) % height + padding_height
: (index / input_channels / width) % height + padding_height);
int d = (data_layout == DataLayout::kNCHW
int d = (data_layout != DataLayout::kNHWC
? (index / width / height) % depth + padding_depth
: index / input_channels / width / height + padding_depth);
int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth
int c = (data_layout != DataLayout::kNHWC ? index / width / height / depth
: index % input_channels);
// compute the start and end of the output
......@@ -247,13 +247,13 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
"The dimension of col should be 7.");
int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]);
(data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]);
(data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]);
(data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]);
(data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1];
int filter_height = col.dims()[2];
int filter_width = col.dims()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册