提交 9107bf20 编写于 作者: A Adam 提交者: Tao Luo

Add template version of UpdatePadding (#21426)

test=develop
上级 ca879e5a
...@@ -64,17 +64,19 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation, ...@@ -64,17 +64,19 @@ inline int ConvOutputSize(int input_size, int filter_size, int dilation,
return output_size; return output_size;
} }
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilation, template <typename T = int>
inline void UpdatePaddingAndDilation(std::vector<T>* paddings,
std::vector<T>* dilation,
const std::string padding_algorithm, const std::string padding_algorithm,
const framework::DDim data_dims, const framework::DDim data_dims,
const std::vector<int>& strides, const std::vector<T>& strides,
const std::vector<int>& ksize) { const std::vector<T>& ksize) {
// set padding size == data_dims.size() * 2 // set padding size == data_dims.size() * 2
auto data_shape = framework::vectorize<int>(data_dims); auto data_shape = framework::vectorize<T>(data_dims);
if (static_cast<int>(paddings->size()) == data_dims.size()) { if (static_cast<int>(paddings->size()) == data_dims.size()) {
for (int i = 0; i < data_dims.size(); ++i) { for (int i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i); T copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
} }
} else { } else {
...@@ -86,11 +88,11 @@ inline void UpdatePaddingAndDilation(std::vector<int>* paddings, ...@@ -86,11 +88,11 @@ inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
// when padding_algorithm is "VALID" or "SAME" // when padding_algorithm is "VALID" or "SAME"
if (padding_algorithm == "SAME") { if (padding_algorithm == "SAME") {
for (int i = 0; i < data_dims.size(); ++i) { for (int i = 0; i < data_dims.size(); ++i) {
int out_size = (data_dims[i] + strides[i] - 1) / strides[i]; T out_size = (data_dims[i] + strides[i] - 1) / strides[i];
int pad_sum = T pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0); std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
int pad_0 = pad_sum / 2; T pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0; T pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0; *(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1; *(paddings->begin() + i * 2 + 1) = pad_1;
......
...@@ -57,17 +57,19 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -57,17 +57,19 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
void Make() override; void Make() override;
}; };
inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
template <typename T = int>
inline void UpdatePadding(std::vector<T>* paddings, const bool global_pooling,
const bool adaptive, const bool adaptive,
const std::string padding_algorithm, const std::string padding_algorithm,
const framework::DDim data_dims, const framework::DDim data_dims,
const std::vector<int>& strides, const std::vector<T>& strides,
const std::vector<int>& ksize) { const std::vector<T>& ksize) {
// set padding size == data_dims.size() * 2 // set padding size == data_dims.size() * 2
auto data_shape = framework::vectorize<int>(data_dims); auto data_shape = framework::vectorize<T>(data_dims);
if (static_cast<int>(paddings->size()) == data_dims.size()) { if (static_cast<int>(paddings->size()) == data_dims.size()) {
for (int i = 0; i < data_dims.size(); ++i) { for (int i = 0; i < data_dims.size(); ++i) {
int copy_pad = *(paddings->begin() + 2 * i); T copy_pad = *(paddings->begin() + 2 * i);
paddings->insert(paddings->begin() + 2 * i + 1, copy_pad); paddings->insert(paddings->begin() + 2 * i + 1, copy_pad);
} }
} else { } else {
...@@ -79,11 +81,11 @@ inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling, ...@@ -79,11 +81,11 @@ inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
// when padding_algorithm is "VALID" or "SAME" // when padding_algorithm is "VALID" or "SAME"
if (padding_algorithm == "SAME") { if (padding_algorithm == "SAME") {
for (int i = 0; i < data_dims.size(); ++i) { for (int i = 0; i < data_dims.size(); ++i) {
int out_size = (data_dims[i] + strides[i] - 1) / strides[i]; T out_size = (data_dims[i] + strides[i] - 1) / strides[i];
int pad_sum = T pad_sum =
std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0); std::max((out_size - 1) * strides[i] + ksize[i] - data_shape[i], 0);
int pad_0 = pad_sum / 2; T pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0; T pad_1 = pad_sum - pad_0;
*(paddings->begin() + i * 2) = pad_0; *(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1; *(paddings->begin() + i * 2 + 1) = pad_1;
} }
...@@ -101,11 +103,12 @@ inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling, ...@@ -101,11 +103,12 @@ inline void UpdatePadding(std::vector<int>* paddings, const bool global_pooling,
} }
} }
inline void UpdateKsize(std::vector<int>* ksize, template <typename T = int>
inline void UpdateKsize(std::vector<T>* ksize,
const framework::DDim data_dims) { const framework::DDim data_dims) {
ksize->resize(static_cast<size_t>(data_dims.size())); ksize->resize(static_cast<size_t>(data_dims.size()));
for (size_t i = 0; i < ksize->size(); ++i) { for (size_t i = 0; i < ksize->size(); ++i) {
*(ksize->begin() + i) = static_cast<int>(data_dims[i]); *(ksize->begin() + i) = static_cast<T>(data_dims[i]);
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册