提交 78910480 · 作者: Zhang Ting · 提交者: Aurelius84

Fix conv_transpose bug: treat every data layout other than NHWC (including AnyLayout) as NCHW, test=develop (#20589)

上级 172e91c0
...@@ -64,7 +64,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -64,7 +64,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"dimension should be the same."); "dimension should be the same.");
const int64_t C = const int64_t C =
(data_layout == DataLayout::kNCHW ? in_dims[1] (data_layout != DataLayout::kNHWC ? in_dims[1]
: in_dims[in_dims.size() - 1]); : in_dims[in_dims.size() - 1]);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
C, filter_dims[0], C, filter_dims[0],
...@@ -72,7 +72,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -72,7 +72,7 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
"be equal to the number of filter's channels."); "be equal to the number of filter's channels.");
framework::DDim in_data_dims; framework::DDim in_data_dims;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
...@@ -84,10 +84,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -84,10 +84,10 @@ void ConvTransposeOp::InferShape(framework::InferShapeContext* ctx) const {
in_data_dims, strides, ksize); in_data_dims, strides, ksize);
std::vector<int64_t> output_shape({in_dims[0]}); std::vector<int64_t> output_shape({in_dims[0]});
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
output_shape.push_back(filter_dims[1] * groups); output_shape.push_back(filter_dims[1] * groups);
} }
const int offset = (data_layout == DataLayout::kNCHW ? 2 : 1); const int offset = (data_layout != DataLayout::kNHWC ? 2 : 1);
for (size_t i = 0; i < strides.size(); ++i) { for (size_t i = 0; i < strides.size(); ++i) {
auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; auto filter_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1;
auto infer_shape = (in_dims[i + offset] - 1) * strides[i] - auto infer_shape = (in_dims[i + offset] - 1) * strides[i] -
......
...@@ -176,7 +176,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -176,7 +176,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
framework::DDim in_data_dims; framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
...@@ -198,7 +198,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -198,7 +198,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w} // col_shape_vec: {o_c/g, k_h, k_w, h, w} or {o_c/g, k_d, k_h, k_w, d, h, w}
size_t data_dim = filter_shape_vec.size() - 2; size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim); std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
col_shape_vec[0] = out_dims[1] / groups; col_shape_vec[0] = out_dims[1] / groups;
for (size_t j = 0; j < data_dim; ++j) { for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2]; col_shape_vec[j + 1] = filter_shape_vec[j + 2];
...@@ -234,7 +234,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -234,7 +234,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
// input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
DDim input_matrix_shape; DDim input_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
input_matrix_shape = {in_dims[1], col_matrix_shape[1]}; input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
} else { } else {
input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]}; input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
...@@ -242,7 +242,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -242,7 +242,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
// filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w) // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
DDim filter_matrix_shape; DDim filter_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
filter_matrix_shape = {in_dims[1], col_matrix_shape[0]}; filter_matrix_shape = {in_dims[1], col_matrix_shape[0]};
} else { } else {
filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]}; filter_matrix_shape = {in_dims[in_dims.size() - 1], col_matrix_shape[0]};
...@@ -256,12 +256,12 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -256,12 +256,12 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
set_zero(dev_ctx, output, static_cast<T>(0)); set_zero(dev_ctx, output, static_cast<T>(0));
int in_step = int in_step =
(data_layout == framework::DataLayout::kNCHW (data_layout != framework::DataLayout::kNHWC
? static_cast<int>(in_dims[1]) / groups ? static_cast<int>(in_dims[1]) / groups
: static_cast<int>(in_dims[in_dims.size() - 1]) / groups); : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
int out_step = int out_step =
(data_layout == framework::DataLayout::kNCHW (data_layout != framework::DataLayout::kNHWC
? static_cast<int>(out_dims[1]) / groups ? static_cast<int>(out_dims[1]) / groups
: static_cast<int>(out_dims[out_dims.size() - 1]) / groups); : static_cast<int>(out_dims[out_dims.size() - 1]) / groups);
math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im; math::Col2ImFunctor<math::ColFormat::kCFO, DeviceContext, T> col2im;
...@@ -284,14 +284,14 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> { ...@@ -284,14 +284,14 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
for (int g = 0; g < groups; g++) { for (int g = 0; g < groups; g++) {
int64_t start = g * in_step; int64_t start = g * in_step;
int64_t end = (g + 1) * in_step; int64_t end = (g + 1) * in_step;
int axes = (data_layout == framework::DataLayout::kNCHW ? 0 : 1); int axes = (data_layout != framework::DataLayout::kNHWC ? 0 : 1);
Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step); Tensor filter_slice = filter.Slice(g * in_step, (g + 1) * in_step);
Tensor in_slice, out_slice; Tensor in_slice, out_slice;
// col_matrix = filter_slice * input_slice // col_matrix = filter_slice * input_slice
// of shape (o_c/g * k_h * k_w, h * w) // of shape (o_c/g * k_h * k_w, h * w)
// or (o_c/g * k_d * k_h * k_w, d * h * w) // or (o_c/g * k_d * k_h * k_w, d * h * w)
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step); in_slice = input_batch.Slice(g * in_step, (g + 1) * in_step);
out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step); out_slice = output_batch.Slice(g * out_step, (g + 1) * out_step);
blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0), blas.MatMul(filter_slice, true, in_slice, false, static_cast<T>(1.0),
...@@ -372,7 +372,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -372,7 +372,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
const int batch_size = static_cast<int>(input->dims()[0]); const int batch_size = static_cast<int>(input->dims()[0]);
framework::DDim in_data_dims; framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
...@@ -394,7 +394,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -394,7 +394,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for // col_shape_vec: {o_c, k_h, k_w, h, w} or {o_c, k_d, k_h, k_w, d, h, w} for
size_t data_dim = filter_shape_vec.size() - 2; size_t data_dim = filter_shape_vec.size() - 2;
std::vector<int64_t> col_shape_vec(1 + 2 * data_dim); std::vector<int64_t> col_shape_vec(1 + 2 * data_dim);
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
col_shape_vec[0] = out_grad_dims[1]; col_shape_vec[0] = out_grad_dims[1];
for (size_t j = 0; j < data_dim; ++j) { for (size_t j = 0; j < data_dim; ++j) {
col_shape_vec[j + 1] = filter_shape_vec[j + 2]; col_shape_vec[j + 1] = filter_shape_vec[j + 2];
...@@ -421,7 +421,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -421,7 +421,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first // input matrix size: (i_c, h * w) or (i_c, d * h * w) for channel_first
// input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last // input matrix size: (h * w, i_c) or (d * h * w, i_c) for channel_last
DDim input_matrix_shape; DDim input_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
input_matrix_shape = {in_dims[1], col_matrix_shape[1]}; input_matrix_shape = {in_dims[1], col_matrix_shape[1]};
} else { } else {
input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]}; input_matrix_shape = {col_matrix_shape[1], in_dims[in_dims.size() - 1]};
...@@ -429,7 +429,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -429,7 +429,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w) // filter size: (i_c, o_c/g * k_h * k_w) or (i_c, o_c/g * k_d * k_h * k_w)
DDim filter_matrix_shape; DDim filter_matrix_shape;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups}; filter_matrix_shape = {in_dims[1], col_matrix_shape[0] / groups};
} else { } else {
filter_matrix_shape = {in_dims[in_dims.size() - 1], filter_matrix_shape = {in_dims[in_dims.size() - 1],
...@@ -438,7 +438,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -438,7 +438,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
filter.Resize(filter_matrix_shape); filter.Resize(filter_matrix_shape);
int in_step = int in_step =
(data_layout == framework::DataLayout::kNCHW (data_layout != framework::DataLayout::kNHWC
? static_cast<int>(in_dims[1]) / groups ? static_cast<int>(in_dims[1]) / groups
: static_cast<int>(in_dims[in_dims.size() - 1]) / groups); : static_cast<int>(in_dims[in_dims.size() - 1]) / groups);
int col_step = static_cast<int>(col_matrix_shape[0]) / groups; int col_step = static_cast<int>(col_matrix_shape[0]) / groups;
...@@ -531,7 +531,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -531,7 +531,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
// k_h * k_w, d * h * w) // k_h * k_w, d * h * w)
Tensor col_matrix_slice = Tensor col_matrix_slice =
col_matrix.Slice(g * col_step, (g + 1) * col_step); col_matrix.Slice(g * col_step, (g + 1) * col_step);
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
Tensor input_grad_slice = Tensor input_grad_slice =
input_grad_batch.Slice(g * in_step, (g + 1) * in_step); input_grad_batch.Slice(g * in_step, (g + 1) * in_step);
blas.MatMul(filter_slice, false, col_matrix_slice, false, blas.MatMul(filter_slice, false, col_matrix_slice, false,
...@@ -579,7 +579,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -579,7 +579,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
filter_grad_.Slice(g * in_step, (g + 1) * in_step); filter_grad_.Slice(g * in_step, (g + 1) * in_step);
Tensor col_matrix_slice = Tensor col_matrix_slice =
col_matrix.Slice(g * col_step, (g + 1) * col_step); col_matrix.Slice(g * col_step, (g + 1) * col_step);
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
Tensor in_batch_slice = Tensor in_batch_slice =
in_batch.Slice(g * in_step, (g + 1) * in_step); in_batch.Slice(g * in_step, (g + 1) * in_step);
blas.MatMul(in_batch_slice, false, col_matrix_slice, true, blas.MatMul(in_batch_slice, false, col_matrix_slice, true,
...@@ -629,7 +629,7 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel<T> { ...@@ -629,7 +629,7 @@ class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
auto filter_dims = filter.dims(); auto filter_dims = filter.dims();
framework::DDim in_data_dims; framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
...@@ -684,7 +684,7 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> { ...@@ -684,7 +684,7 @@ class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
auto filter_dims = filter.dims(); auto filter_dims = filter.dims();
framework::DDim in_data_dims; framework::DDim in_data_dims;
if (data_layout == framework::DataLayout::kNCHW) { if (data_layout != framework::DataLayout::kNHWC) {
in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size()); in_data_dims = framework::slice_ddim(in_dims, 2, in_dims.size());
} else { } else {
in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1); in_data_dims = framework::slice_ddim(in_dims, 1, in_dims.size() - 1);
......
...@@ -74,11 +74,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO, ...@@ -74,11 +74,11 @@ class Col2ImFunctor<paddle::operators::math::ColFormat::kCFO,
PADDLE_ENFORCE_EQ(col.dims().size(), 5, PADDLE_ENFORCE_EQ(col.dims().size(), 5,
"The dimension of col should be 5."); "The dimension of col should be 5.");
int im_channels = int im_channels =
(data_layout == DataLayout::kNCHW ? im->dims()[0] : im->dims()[2]); (data_layout != DataLayout::kNHWC ? im->dims()[0] : im->dims()[2]);
int im_height = int im_height =
(data_layout == DataLayout::kNCHW ? im->dims()[1] : im->dims()[0]); (data_layout != DataLayout::kNHWC ? im->dims()[1] : im->dims()[0]);
int im_width = int im_width =
(data_layout == DataLayout::kNCHW ? im->dims()[2] : im->dims()[1]); (data_layout != DataLayout::kNHWC ? im->dims()[2] : im->dims()[1]);
int filter_height = col.dims()[1]; int filter_height = col.dims()[1];
int filter_width = col.dims()[2]; int filter_width = col.dims()[2];
int col_height = col.dims()[3]; int col_height = col.dims()[3];
......
...@@ -33,11 +33,11 @@ inline void im2col_common(const framework::Tensor& im, ...@@ -33,11 +33,11 @@ inline void im2col_common(const framework::Tensor& im,
framework::Tensor* col, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) { const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels = int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height = int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width = int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1]; int filter_height = col->dims()[1];
int filter_width = col->dims()[2]; int filter_width = col->dims()[2];
int output_height = col->dims()[3]; int output_height = col->dims()[3];
...@@ -55,7 +55,7 @@ inline void im2col_common(const framework::Tensor& im, ...@@ -55,7 +55,7 @@ inline void im2col_common(const framework::Tensor& im,
for (int w = 0; w < output_width; ++w) { for (int w = 0; w < output_width; ++w) {
int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1]; int im_col_idx = w * stride[1] - padding[1] + w_offset * dilation[1];
int im_idx; int im_idx;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx; im_idx = (im_row_idx + c_im * im_height) * im_width + im_col_idx;
} else { } else {
im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im; im_idx = (im_row_idx * im_width + im_col_idx) * im_channels + c_im;
...@@ -79,11 +79,11 @@ inline void im2col_sh1sw1dh1dw1ph0pw0( ...@@ -79,11 +79,11 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(
const framework::Tensor& im, framework::Tensor* col, const framework::Tensor& im, framework::Tensor* col,
const DataLayout data_layout = DataLayout::kNCHW) { const DataLayout data_layout = DataLayout::kNCHW) {
int im_channels = int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height = int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width = int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1]; int filter_height = col->dims()[1];
int filter_width = col->dims()[2]; int filter_width = col->dims()[2];
int output_height = col->dims()[3]; int output_height = col->dims()[3];
...@@ -103,7 +103,7 @@ inline void im2col_sh1sw1dh1dw1ph0pw0( ...@@ -103,7 +103,7 @@ inline void im2col_sh1sw1dh1dw1ph0pw0(
const T* src_data = src_data_ic; const T* src_data = src_data_ic;
for (int kh = 0; kh < filter_height; ++kh) { for (int kh = 0; kh < filter_height; ++kh) {
for (int kw = 0; kw < filter_width; ++kw) { for (int kw = 0; kw < filter_width; ++kw) {
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data, src_data + kw, copy_size); std::memcpy(dst_data, src_data + kw, copy_size);
} else { } else {
for (int kow = 0; kow < output_width; ++kow) { for (int kow = 0; kow < output_width; ++kow) {
...@@ -131,11 +131,11 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, ...@@ -131,11 +131,11 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
framework::Tensor* col, framework::Tensor* col,
const DataLayout data_layout) { const DataLayout data_layout) {
int im_channels = int im_channels =
(data_layout == DataLayout::kNCHW ? im.dims()[0] : im.dims()[2]); (data_layout != DataLayout::kNHWC ? im.dims()[0] : im.dims()[2]);
int im_height = int im_height =
(data_layout == DataLayout::kNCHW ? im.dims()[1] : im.dims()[0]); (data_layout != DataLayout::kNHWC ? im.dims()[1] : im.dims()[0]);
int im_width = int im_width =
(data_layout == DataLayout::kNCHW ? im.dims()[2] : im.dims()[1]); (data_layout != DataLayout::kNHWC ? im.dims()[2] : im.dims()[1]);
int filter_height = col->dims()[1]; int filter_height = col->dims()[1];
int filter_width = col->dims()[2]; int filter_width = col->dims()[2];
int output_height = col->dims()[3]; int output_height = col->dims()[3];
...@@ -205,7 +205,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, ...@@ -205,7 +205,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
dst_data = dst_data + col_matrix_width; dst_data = dst_data + col_matrix_width;
continue; continue;
} }
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data + plw, src_data, copy_size); std::memcpy(dst_data + plw, src_data, copy_size);
} else { } else {
for (int kow = 0; kow < output_width - plw - prw; ++kow) { for (int kow = 0; kow < output_width - plw - prw; ++kow) {
...@@ -261,7 +261,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, ...@@ -261,7 +261,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
// TODO(TJ): reuse plw-kw outside this for // TODO(TJ): reuse plw-kw outside this for
// try to unify // try to unify
for (int kw = 0; kw < plw; ++kw) { for (int kw = 0; kw < plw; ++kw) {
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data + (plw - kw), src_data, std::memcpy(dst_data + (plw - kw), src_data,
sizeof(T) * (output_width - (plw - kw))); sizeof(T) * (output_width - (plw - kw)));
} else { } else {
...@@ -276,7 +276,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, ...@@ -276,7 +276,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
dst_data = dst_data + col_matrix_width; dst_data = dst_data + col_matrix_width;
} }
for (int kw = plw; kw < filter_width - prw; ++kw) { for (int kw = plw; kw < filter_width - prw; ++kw) {
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data, src_data + (kw - plw), std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * output_width); sizeof(T) * output_width);
} else { } else {
...@@ -292,7 +292,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im, ...@@ -292,7 +292,7 @@ inline void im2col_sh1sw1dh1dw1ph1pw1(const framework::Tensor& im,
} }
int i = 1; int i = 1;
for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) { for (int kw = filter_width - prw; kw < filter_width; ++kw, ++i) {
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
std::memcpy(dst_data, src_data + (kw - plw), std::memcpy(dst_data, src_data + (kw - plw),
sizeof(T) * (output_width - i)); sizeof(T) * (output_width - i));
} else { } else {
......
...@@ -40,13 +40,13 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> { ...@@ -40,13 +40,13 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
"The dimension of col should be 7."); "The dimension of col should be 7.");
int input_channels = int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]); (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth = int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]); (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height = int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]); (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width = int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]); (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1]; int filter_depth = col->dims()[1];
int filter_height = col->dims()[2]; int filter_height = col->dims()[2];
int filter_width = col->dims()[3]; int filter_width = col->dims()[3];
...@@ -104,7 +104,7 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> { ...@@ -104,7 +104,7 @@ class Vol2ColFunctor<platform::CPUDeviceContext, T> {
int col_idx = int col_idx =
((c * output_depth + d) * output_height + h) * output_width + w; ((c * output_depth + d) * output_height + h) * output_width + w;
int vol_idx; int vol_idx;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) * vol_idx = ((c_in * input_depth + d_pad) * input_height + h_pad) *
input_width + input_width +
w_pad; w_pad;
...@@ -146,13 +146,13 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> { ...@@ -146,13 +146,13 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
"The dimension of col should be 7."); "The dimension of col should be 7.");
int input_channels = int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]); (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth = int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]); (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height = int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]); (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width = int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]); (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1]; int filter_depth = col.dims()[1];
int filter_height = col.dims()[2]; int filter_height = col.dims()[2];
int filter_width = col.dims()[3]; int filter_width = col.dims()[3];
...@@ -209,7 +209,7 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> { ...@@ -209,7 +209,7 @@ class Col2VolFunctor<platform::CPUDeviceContext, T> {
if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 && if (h_pad >= 0 && h_pad < input_height && w_pad >= 0 &&
w_pad < input_width && d_pad >= 0 && d_pad < input_depth) { w_pad < input_width && d_pad >= 0 && d_pad < input_depth) {
int vol_idx; int vol_idx;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) * vol_idx = ((cIm * input_depth + d_pad) * input_height + h_pad) *
input_width + input_width +
w_pad; w_pad;
......
...@@ -55,7 +55,7 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth, ...@@ -55,7 +55,7 @@ __global__ void vol2col(int num_kernels, const T* data_vol, int depth,
int h = h_in + i * dilation_h; int h = h_in + i * dilation_h;
int w = w_in + j * dilation_w; int w = w_in + j * dilation_w;
int vol_idx; int vol_idx;
if (data_layout == DataLayout::kNCHW) { if (data_layout != DataLayout::kNHWC) {
vol_idx = ((channel_in * depth + d) * height + h) * width + w; vol_idx = ((channel_in * depth + d) * height + h) * width + w;
} else { } else {
vol_idx = vol_idx =
...@@ -96,13 +96,13 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> { ...@@ -96,13 +96,13 @@ class Vol2ColFunctor<platform::CUDADeviceContext, T> {
"The dimension of col should be 7."); "The dimension of col should be 7.");
int input_channels = int input_channels =
(data_layout == DataLayout::kNCHW ? vol.dims()[0] : vol.dims()[3]); (data_layout != DataLayout::kNHWC ? vol.dims()[0] : vol.dims()[3]);
int input_depth = int input_depth =
(data_layout == DataLayout::kNCHW ? vol.dims()[1] : vol.dims()[0]); (data_layout != DataLayout::kNHWC ? vol.dims()[1] : vol.dims()[0]);
int input_height = int input_height =
(data_layout == DataLayout::kNCHW ? vol.dims()[2] : vol.dims()[1]); (data_layout != DataLayout::kNHWC ? vol.dims()[2] : vol.dims()[1]);
int input_width = int input_width =
(data_layout == DataLayout::kNCHW ? vol.dims()[3] : vol.dims()[2]); (data_layout != DataLayout::kNHWC ? vol.dims()[3] : vol.dims()[2]);
int filter_depth = col->dims()[1]; int filter_depth = col->dims()[1];
int filter_height = col->dims()[2]; int filter_height = col->dims()[2];
int filter_width = col->dims()[3]; int filter_width = col->dims()[3];
...@@ -170,16 +170,16 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth, ...@@ -170,16 +170,16 @@ __global__ void col2vol(int num_kernels, const T* data_col, int depth,
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < num_kernels;
index += blockDim.x * gridDim.x) { index += blockDim.x * gridDim.x) {
T src_val = 0; T src_val = 0;
int w = (data_layout == DataLayout::kNCHW int w = (data_layout != DataLayout::kNHWC
? index % width + padding_width ? index % width + padding_width
: (index / input_channels) % width + padding_width); : (index / input_channels) % width + padding_width);
int h = (data_layout == DataLayout::kNCHW int h = (data_layout != DataLayout::kNHWC
? (index / width) % height + padding_height ? (index / width) % height + padding_height
: (index / input_channels / width) % height + padding_height); : (index / input_channels / width) % height + padding_height);
int d = (data_layout == DataLayout::kNCHW int d = (data_layout != DataLayout::kNHWC
? (index / width / height) % depth + padding_depth ? (index / width / height) % depth + padding_depth
: index / input_channels / width / height + padding_depth); : index / input_channels / width / height + padding_depth);
int c = (data_layout == DataLayout::kNCHW ? index / width / height / depth int c = (data_layout != DataLayout::kNHWC ? index / width / height / depth
: index % input_channels); : index % input_channels);
// compute the start and end of the output // compute the start and end of the output
...@@ -247,13 +247,13 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> { ...@@ -247,13 +247,13 @@ class Col2VolFunctor<platform::CUDADeviceContext, T> {
"The dimension of col should be 7."); "The dimension of col should be 7.");
int input_channels = int input_channels =
(data_layout == DataLayout::kNCHW ? vol->dims()[0] : vol->dims()[3]); (data_layout != DataLayout::kNHWC ? vol->dims()[0] : vol->dims()[3]);
int input_depth = int input_depth =
(data_layout == DataLayout::kNCHW ? vol->dims()[1] : vol->dims()[0]); (data_layout != DataLayout::kNHWC ? vol->dims()[1] : vol->dims()[0]);
int input_height = int input_height =
(data_layout == DataLayout::kNCHW ? vol->dims()[2] : vol->dims()[1]); (data_layout != DataLayout::kNHWC ? vol->dims()[2] : vol->dims()[1]);
int input_width = int input_width =
(data_layout == DataLayout::kNCHW ? vol->dims()[3] : vol->dims()[2]); (data_layout != DataLayout::kNHWC ? vol->dims()[3] : vol->dims()[2]);
int filter_depth = col.dims()[1]; int filter_depth = col.dims()[1];
int filter_height = col.dims()[2]; int filter_height = col.dims()[2];
int filter_width = col.dims()[3]; int filter_width = col.dims()[3];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册