From 9107bf209faae5c0ecf8dd9f95de5fae3a290ea4 Mon Sep 17 00:00:00 2001 From: Adam <38704900+grygielski@users.noreply.github.com> Date: Fri, 29 Nov 2019 04:22:30 +0100 Subject: [PATCH] Add template version of UpdatePadding (#21426) test=develop --- paddle/fluid/operators/conv_op.h | 22 ++++++++++++---------- paddle/fluid/operators/pool_op.h | 25 ++++++++++++++----------- 2 files changed, 26 insertions(+), 21 deletions(-) diff --git a/paddle/fluid/operators/conv_op.h b/paddle/fluid/operators/conv_op.h index caf3a9a4bc..ad025c4683 100644 --- a/paddle/fluid/operators/conv_op.h +++ b/paddle/fluid/operators/conv_op.h @@ -64,17 +64,19 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, return output_size; } -inline void UpdatePaddingAndDilation(std::vector* paddings, - std::vector* dilation, + +template +inline void UpdatePaddingAndDilation(std::vector* paddings, + std::vector* dilation, const std::string padding_algorithm, const framework::DDim data_dims, - const std::vector& strides, - const std::vector& ksize) { + const std::vector& strides, + const std::vector& ksize) { // set padding size == data_dims.size() * 2 - auto data_shape = framework::vectorize(data_dims); + auto data_shape = framework::vectorize(data_dims); if (static_cast(paddings->size()) == data_dims.size()) { for (int i = 0; i < data_dims.size(); ++i) { - int copy_pad = *(paddings->begin() + 2 * i); + T copy_pad = *(paddings->begin() + 2 * i); paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); } } else { @@ -86,11 +88,11 @@ inline void UpdatePaddingAndDilation(std::vector* paddings, // when padding_algorithm is "VALID" or "SAME" if (padding_algorithm == "SAME") { for (int i = 0; i < data_dims.size(); ++i) { - int out_size = (data_dims[i] + strides[i] - 1) / strides[i]; - int pad_sum = + T out_size = (data_dims[i] + strides[i] - 1) / strides[i]; + T pad_sum = std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0); - int pad_0 = pad_sum / 2; - int pad_1 = pad_sum - pad_0; + T pad_0 = pad_sum / 2; + T pad_1 = pad_sum - pad_0; *(paddings->begin() + i * 2) = pad_0; *(paddings->begin() + i * 2 + 1) = pad_1; diff --git a/paddle/fluid/operators/pool_op.h b/paddle/fluid/operators/pool_op.h index 5db94f95b9..53551a8c50 100644 --- a/paddle/fluid/operators/pool_op.h +++ b/paddle/fluid/operators/pool_op.h @@ -57,17 +57,19 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override; }; -inline void UpdatePadding(std::vector* paddings, const bool global_pooling, + +template +inline void UpdatePadding(std::vector* paddings, const bool global_pooling, const bool adaptive, const std::string padding_algorithm, const framework::DDim data_dims, - const std::vector& strides, - const std::vector& ksize) { + const std::vector& strides, + const std::vector& ksize) { // set padding size == data_dims.size() * 2 - auto data_shape = framework::vectorize(data_dims); + auto data_shape = framework::vectorize(data_dims); if (static_cast(paddings->size()) == data_dims.size()) { for (int i = 0; i < data_dims.size(); ++i) { - int copy_pad = *(paddings->begin() + 2 * i); + T copy_pad = *(paddings->begin() + 2 * i); paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); } } else { @@ -79,11 +81,11 @@ inline void UpdatePadding(std::vector* paddings, const bool global_pooling, // when padding_algorithm is "VALID" or "SAME" if (padding_algorithm == "SAME") { for (int i = 0; i < data_dims.size(); ++i) { - int out_size = (data_dims[i] + strides[i] - 1) / strides[i]; - int pad_sum = + T out_size = (data_dims[i] + strides[i] - 1) / strides[i]; + T pad_sum = std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0); - int pad_0 = pad_sum / 2; - int pad_1 = pad_sum - pad_0; + T pad_0 = pad_sum / 2; + T pad_1 = pad_sum - pad_0; *(paddings->begin() + i * 2) = pad_0; *(paddings->begin() + i * 2 + 1) = pad_1; } @@ -101,11 +103,12 @@ inline void UpdatePadding(std::vector* paddings, const bool global_pooling, } } -inline void UpdateKsize(std::vector* ksize, +template +inline void UpdateKsize(std::vector* ksize, const framework::DDim data_dims) { ksize->resize(static_cast(data_dims.size())); for (size_t i = 0; i < ksize->size(); ++i) { - *(ksize->begin() + i) = static_cast(data_dims[i]); + *(ksize->begin() + i) = static_cast(data_dims[i]); } } -- GitLab