From e5bf9c5670682a8931b8a94a7c683f3dae1193b4 Mon Sep 17 00:00:00 2001
From: chengduoZH
Date: Tue, 21 Nov 2017 10:07:41 +0800
Subject: [PATCH] remove vector::erase

---
 paddle/operators/conv_op.h           | 54 ++++++++++++----------------
 paddle/operators/conv_transpose_op.h | 46 +++++++++++-------------
 2 files changed, 43 insertions(+), 57 deletions(-)

diff --git a/paddle/operators/conv_op.h b/paddle/operators/conv_op.h
index fac5f1d0e25..152d6b5132e 100644
--- a/paddle/operators/conv_op.h
+++ b/paddle/operators/conv_op.h
@@ -38,7 +38,7 @@ inline bool IsExpand(std::vector<int64_t>& filter_dim,
                      std::vector<int>& dilations) {
   bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
   for (size_t j = 0; j < strides.size(); ++j) {
-    filter_1 = filter_1 && (static_cast<int>(filter_dim[j]) == 1);
+    filter_1 = filter_1 && (static_cast<int>(filter_dim[j + 2]) == 1);
     strides_1 = strides_1 && (strides[j] == 1);
     padding_0 = padding_0 && (paddings[j] == 0);
     dilation_1 = dilation_1 && (dilations[j] == 1);
@@ -91,24 +91,20 @@ class GemmConvKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(framework::vectorize(output->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       output_shape_vec.size() - 3);
+    col_shape_vec.assign(1, input->dims()[1] / groups);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                          output_shape_vec.end());
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
@@ -116,7 +112,7 @@ class GemmConvKernel : public framework::OpKernel<T> {
     // size: (i_c/g * k_h * k_w, o_h * o_w) or (i_c/g * k_d * k_h * k_w, o_d *
     // o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     bool is_expand = IsExpand(filter_shape_vec, strides, paddings, dilations);
     Tensor col;
@@ -159,13 +155,13 @@ class GemmConvKernel : public framework::OpKernel<T> {
           col.ShareDataWith(in_slice);
           col_matrix.ShareDataWith(col);
           col_matrix.Resize(col_matrix_shape);
-        } else if (filter_shape_vec.size() == 2) {
+        } else if (filter_shape_vec.size() == 4) {
           // im2col
           im2col(context.device_context(), in_slice, dilations, strides,
                  std::vector<int>{paddings[0], paddings[1], paddings[0],
                                   paddings[1]},
                  &col);
-        } else if (filter_shape_vec.size() == 3) {
+        } else if (filter_shape_vec.size() == 5) {
           // vol2col
           vol2col(context.device_context(), in_slice, dilations, strides,
                   paddings, &col);
@@ -206,25 +202,21 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_i, k_h, k_w} or {k_o, k_i, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec(framework::vectorize(filter.dims()));
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
-
-    // output_shape_vec: {o_h, o_w} or {o_d, o_h, o_w}
+    // output_shape_vec: {o_n, o_c, o_h, o_w} or {o_n, o_c, o_d, o_h, o_w}
     std::vector<int64_t> output_shape_vec(
         framework::vectorize(output_grad->dims()));
-    output_shape_vec.erase(output_shape_vec.begin(),
-                           output_shape_vec.begin() + 2);
 
     // use col_shape in the im2col calculation
     // col_shape_vec: {i_c/g, k_h, k_w, o_h, o_w} or {i_c/g, k_d, k_h, k_w, o_d,
     // o_h, o_w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(input->dims()[1] / groups);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       output_shape_vec.size() - 3);
+    col_shape_vec.assign(1, input->dims()[1] / groups);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                          output_shape_vec.end());
     framework::DDim col_shape(framework::make_ddim(col_shape_vec));
 
@@ -233,7 +225,7 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
     // or
     // (i_c/g * k_d * k_h * k_w, o_d * o_h * o_w)
     framework::DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     framework::DDim input_shape = framework::slice_ddim(
         input->dims(), 1, static_cast<int>(input->dims().size()));
@@ -294,12 +286,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
                                  out_grad_slice, false, T(1.0), &col_matrix,
                                  T(0.0));
 
-          if (is_expand && filter_shape_vec.size() == 2) {
+          if (is_expand && filter_shape_vec.size() == 4) {
            col2im(context.device_context(), col, dilations, strides,
                    std::vector<int>{paddings[0], paddings[1], paddings[0],
                                     paddings[1]},
                    &in_grad_slice);
-          } else if (is_expand && filter_shape_vec.size() == 3) {
+          } else if (is_expand && filter_shape_vec.size() == 5) {
            col2vol(context.device_context(), col, dilations, strides,
                    paddings, &in_grad_slice);
           }
@@ -328,12 +320,12 @@ class GemmConvGradKernel : public framework::OpKernel<T> {
            col.ShareDataWith(in_slice);
            col_matrix.ShareDataWith(col);
            col_matrix.Resize(col_matrix_shape);
-          } else if (filter_shape_vec.size() == 2) {
+          } else if (filter_shape_vec.size() == 4) {
            im2col(context.device_context(), in_slice, dilations, strides,
                   std::vector<int>{paddings[0], paddings[1], paddings[0],
                                    paddings[1]},
                   &col);
-          } else if (filter_shape_vec.size() == 3) {
+          } else if (filter_shape_vec.size() == 5) {
            vol2col(context.device_context(), in_slice, dilations, strides,
                    paddings, &col);
           }
diff --git a/paddle/operators/conv_transpose_op.h b/paddle/operators/conv_transpose_op.h
index ab336ad23ce..e9c953699e7 100644
--- a/paddle/operators/conv_transpose_op.h
+++ b/paddle/operators/conv_transpose_op.h
@@ -68,30 +68,27 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       input_shape_vec.size() - 3);
+    col_shape_vec.assign(1, output->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
                          input_shape_vec.end());
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     Tensor col;
     col.mutable_data<T>(col_shape, context.GetPlace());
@@ -136,7 +133,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                             input_batch, false, static_cast<T>(1.0),
                             &col_matrix, static_cast<T>(0.0));
 
-      if (filter_shape_vec.size() == 2) {
+      if (filter_shape_vec.size() == 4) {
         // col2im: col_matrix -> dy
         // from (c * k_h * k_w, h * w) to (c, o_h, o_w)
         col2im(context.device_context(), col,
@@ -144,7 +141,7 @@ class GemmConvTransposeKernel : public framework::OpKernel<T> {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &output_batch);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (filter_shape_vec.size() == 5) {
         // col2vol: col_matrix -> dy
         // from (c * k_d * k_h * k_w, d * h * w) to (c, o_d, o_h, o_w)
         col2vol(context.device_context(), col, dilations, strides, paddings,
@@ -176,30 +173,27 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
 
     const int batch_size = static_cast<int>(input->dims()[0]);
 
-    // input_shape_vec: {h, w} or {d, h, w}
+    // input_shape_vec: {n, c, h, w} or {n, c, d, h, w}
     std::vector<int64_t> input_shape_vec = framework::vectorize(input->dims());
-    input_shape_vec.erase(input_shape_vec.begin(), input_shape_vec.begin() + 2);
-
-    // filter_shape_vec: {k_h, k_w} or {k_d, k_h, k_w}
+    // filter_shape_vec: {k_o, k_c, k_h, k_w} or {k_o, k_c, k_d, k_h, k_w}
     std::vector<int64_t> filter_shape_vec = framework::vectorize(filter.dims());
-    filter_shape_vec.erase(filter_shape_vec.begin(),
-                           filter_shape_vec.begin() + 2);
 
     // use col_shape in the im2col and col2im (or vol2col and col2vol)
     // calculation
     // col_shape_vec: {c, k_h, k_w, h, w} or {c, k_d, k_h, k_w, d, h, w}
-    std::vector<int64_t> col_shape_vec;
-    col_shape_vec.push_back(output_grad->dims()[1]);
-    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin(),
+    std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
+                                       input_shape_vec.size() - 3);
+    col_shape_vec.assign(1, output_grad->dims()[1]);
+    col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                          filter_shape_vec.end());
-    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin(),
+    col_shape_vec.insert(col_shape_vec.end(), input_shape_vec.begin() + 2,
                          input_shape_vec.end());
     DDim col_shape(framework::make_ddim(col_shape_vec));
 
     // use col_matrix_shape in the gemm calculation
     // size: (c * k_h * k_w, h * w) or (c * k_d * k_h * k_w, d * h * w)
     DDim col_matrix_shape =
-        framework::flatten_to_2d(col_shape, filter_shape_vec.size() + 1);
+        framework::flatten_to_2d(col_shape, filter_shape_vec.size() - 2 + 1);
 
     // output size: (c, o_h, o_w) or (c, o_d, o_h, o_w)
     DDim output_shape = framework::slice_ddim(output_grad->dims(), 1,
@@ -248,7 +242,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
       Tensor output_grad_batch =
           output_grad->Slice(i, i + 1).Resize(output_shape);
 
-      if (filter_shape_vec.size() == 2) {
+      if (filter_shape_vec.size() == 4) {
         // im2col: dy -> col matrix
         // from (c, o_h, o_w) to (c * k_h * k_w, h * w)
         im2col(context.device_context(), output_grad_batch,
@@ -256,7 +250,7 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
                std::vector<int>{paddings[0], paddings[1], paddings[0],
                                 paddings[1]},
                &col);
-      } else if (filter_shape_vec.size() == 3) {
+      } else if (filter_shape_vec.size() == 5) {
         // vol2col: dy -> col_matrix
         // from (c, o_d, o_h, o_w) to (c * k_d * k_h * k_w, d * h * w)
         vol2col(context.device_context(), output_grad_batch, dilations,
--
GitLab
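
For reference, a minimal standalone sketch of the bookkeeping this patch switches to (not part of the patch; the concrete shapes, values, and main() harness are hypothetical). Because framework::vectorize now returns the full filter and output shapes, the spatial entries sit at an offset of 2, which is why IsExpand reads filter_dim[j + 2] and why the rank checks become 4 (2-D conv) and 5 (3-D conv) instead of 2 and 3:

// Standalone sketch of the new shape bookkeeping; shapes are hypothetical.
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Full shapes, as framework::vectorize now returns them (nothing erased).
  std::vector<int64_t> filter_shape_vec{6, 3, 3, 3};    // {k_o, k_i, k_h, k_w}
  std::vector<int64_t> output_shape_vec{1, 6, 30, 30};  // {o_n, o_c, o_h, o_w}
  std::vector<int> strides{1, 1}, paddings{0, 0}, dilations{1, 1};
  const int64_t input_channels = 3, groups = 1;

  // IsExpand-style check: spatial filter dims are now read at index j + 2.
  bool filter_1 = true, strides_1 = true, padding_0 = true, dilation_1 = true;
  for (size_t j = 0; j < strides.size(); ++j) {
    filter_1 = filter_1 && (static_cast<int>(filter_shape_vec[j + 2]) == 1);
    strides_1 = strides_1 && (strides[j] == 1);
    padding_0 = padding_0 && (paddings[j] == 0);
    dilation_1 = dilation_1 && (dilations[j] == 1);
  }
  bool is_expand = !(filter_1 && strides_1 && padding_0 && dilation_1);

  // Build col_shape_vec = {i_c/g, k_h, k_w, o_h, o_w} by skipping the first
  // two entries of each shape instead of erasing them. Note that assign()
  // replaces the contents, so the sized constructor only pre-allocates.
  std::vector<int64_t> col_shape_vec(filter_shape_vec.size() +
                                     output_shape_vec.size() - 3);
  col_shape_vec.assign(1, input_channels / groups);
  col_shape_vec.insert(col_shape_vec.end(), filter_shape_vec.begin() + 2,
                       filter_shape_vec.end());
  col_shape_vec.insert(col_shape_vec.end(), output_shape_vec.begin() + 2,
                       output_shape_vec.end());

  std::cout << "is_expand = " << is_expand << "\ncol_shape = ";
  for (int64_t d : col_shape_vec) std::cout << d << ' ';  // prints 3 3 3 30 30
  std::cout << '\n';
}

The resulting col_shape_vec is identical to what the erased-vector version produced; only the indexing convention changes, and filter_shape_vec.size() - 2 now plays the role the old filter_shape_vec.size() did in flatten_to_2d.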