提交 70d4809f 编写于 作者: Y yiicy 提交者: Xiaoyang LI

[cherry-pick][ARM] conv_transpose operator support padding_algorithm

上级 9ceb67bf
...@@ -32,8 +32,10 @@ void col2im<float>(const float* data_col, ...@@ -32,8 +32,10 @@ void col2im<float>(const float* data_col,
const int width, const int width,
const int kernel_h, const int kernel_h,
const int kernel_w, const int kernel_w,
const int pad_h, const int pad_h0,
const int pad_w, const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h, const int stride_h,
const int stride_w, const int stride_w,
const int dilation_h, const int dilation_h,
...@@ -41,19 +43,22 @@ void col2im<float>(const float* data_col, ...@@ -41,19 +43,22 @@ void col2im<float>(const float* data_col,
float* data_im) { float* data_im) {
memset(data_im, 0, height * width * channels * sizeof(float)); memset(data_im, 0, height * width * channels * sizeof(float));
const int output_h = const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; (height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
stride_h +
1;
const int output_w = const int output_w =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; (width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int channel_size = height * width; const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) { for (int channel = channels; channel--; data_im += channel_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h + kernel_row * dilation_h; int input_row = -pad_h0 + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) { for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
data_col += output_w; data_col += output_w;
} else { } else {
int input_col = -pad_w + kernel_col * dilation_w; int input_col = -pad_w0 + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) { for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) { if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
data_im[input_row * width + input_col] += *data_col; data_im[input_row * width + input_col] += *data_col;
......
...@@ -26,8 +26,10 @@ void col2im(const Dtype* data_col, ...@@ -26,8 +26,10 @@ void col2im(const Dtype* data_col,
const int width, const int width,
const int kernel_h, const int kernel_h,
const int kernel_w, const int kernel_w,
const int pad_h, const int pad_h0,
const int pad_w, const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h, const int stride_h,
const int stride_w, const int stride_w,
const int dilation_h, const int dilation_h,
......
...@@ -101,7 +101,6 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_ ...@@ -101,7 +101,6 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra) lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra)
lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm) lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm) lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm)
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm) lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm) lite_cc_test(test_lookup_table_compute_arm SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_arm)
......
...@@ -96,7 +96,8 @@ void Conv2DTransposeCompute::Run() { ...@@ -96,7 +96,8 @@ void Conv2DTransposeCompute::Run() {
int group_size_weights = ((m_roundup * k + 15) / 16) * 16; int group_size_weights = ((m_roundup * k + 15) / 16) * 16;
bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) && bool flag_1x1s1p1 = (kw == 1) && (kh == 1) && (param.strides[0] == 1) &&
(param.strides[1] == 1) && pads_all_qual && (param.strides[1] == 1) && pads_all_qual &&
(dilations[0] == 1) && (dilations[1] == 1); (paddings[0] == 0) && (dilations[0] == 1) &&
(dilations[1] == 1);
ctx.ExtendWorkspace(sizeof(float) * group * m * n); ctx.ExtendWorkspace(sizeof(float) * group * m * n);
auto din = param.x->data<float>(); auto din = param.x->data<float>();
...@@ -138,7 +139,9 @@ void Conv2DTransposeCompute::Run() { ...@@ -138,7 +139,9 @@ void Conv2DTransposeCompute::Run() {
kh, kh,
kw, kw,
paddings[0], paddings[0],
paddings[1],
paddings[2], paddings[2],
paddings[3],
param.strides[0], param.strides[0],
param.strides[1], param.strides[1],
dilations[0], dilations[0],
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/conv_transpose_compute.h"
#include <gtest/gtest.h>
#include <cmath>
#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
// Naive reference GEMM used to validate optimized kernels:
//
//   c = alpha * op(a) * op(b) + beta * c (+ bias) (then optional ReLU)
//
// All matrices are dense and row-major.
//   m, n, k    : op(a) is m x k, op(b) is k x n, c is m x n.
//   a          : stored as m x k, or as k x m when trans_a is true.
//   b          : stored as k x n, or as n x k when trans_b is true.
//   bias       : one value per output row (length m); only read when
//                flag_bias is true, so it may be nullptr otherwise.
//   c          : output; its previous content is only read when beta != 0.
//   flag_relu  : clamp negative results to zero.
//
// Following the usual BLAS convention, c does not have to be initialized when
// beta is zero: the old value is not read at all, so NaN/Inf or uninitialized
// garbage in the output buffer cannot leak into the result.
template <typename type, typename type2>
static void basic_gemm(int m,
                       int n,
                       int k,
                       const type* a,
                       const type* b,
                       const type2* bias,
                       type2* c,
                       type2 alpha,
                       type2 beta,
                       bool trans_a = false,
                       bool trans_b = false,
                       bool flag_bias = false,
                       bool flag_relu = false) {
  const type2 zero = static_cast<type2>(0);
  const bool accumulate = (beta != zero);
#pragma omp parallel for
  for (int i = 0; i < m; ++i) {
    const type2 bias_data = flag_bias ? bias[i] : zero;
    for (int j = 0; j < n; ++j) {
      type2 sum = zero;
      for (int l = 0; l < k; ++l) {
        const type av = trans_a ? a[l * m + i] : a[i * k + l];
        const type bv = trans_b ? b[j * k + l] : b[l * n + j];
        sum += av * bv;
      }
      const int idx = i * n + j;
      // Keep the original evaluation order when the old value is used so
      // floating-point results stay bit-identical for beta != 0.
      const type2 tmp = accumulate ? (alpha * sum + beta * c[idx] + bias_data)
                                   : (alpha * sum + bias_data);
      c[idx] = (flag_relu && !(tmp > zero)) ? zero : tmp;
    }
  }
}
//! Reference transposed convolution (deconvolution) for one NCHW tensor,
//! implemented as a per-group GEMM (weights^T x input) followed by col2im.
//!
//! For float, Dtype1 and Dtype2 are both float.
//! For int8, Dtype1 is char and Dtype2 is int.
//!
//! din/dout   : input / output buffers, num batches each.
//! chout/hout/wout, chin/hin/win : output and input channel/height/width.
//! weights    : filter data, group_size_weights elements per group.
//! bias       : per-output-channel bias, only used when flag_bias is true.
//! pad_w/pad_h: symmetric padding in width / height.
//! flag_relu  : apply ReLU to the result.
//!
//! Returns true on success. The CHECK_OR_FALSE guards are presumably an early
//! `return false` when channels are not divisible by group - TODO confirm
//! against the macro definition.
template <typename Dtype1, typename Dtype2>
bool deconv_basic(const Dtype1* din,
                  Dtype2* dout,
                  int num,
                  int chout,
                  int hout,
                  int wout,
                  int chin,
                  int hin,
                  int win,
                  const Dtype1* weights,
                  const Dtype2* bias,
                  int group,
                  int kernel_w,
                  int kernel_h,
                  int stride_w,
                  int stride_h,
                  int dila_w,
                  int dila_h,
                  int pad_w,
                  int pad_h,
                  bool flag_bias,
                  bool flag_relu) {
  // GEMM sizes for a single group: (m x k) * (k x n) -> (m x n).
  int m = chout * kernel_w * kernel_h / group;
  int n = hin * win;
  int k = chin / group;
  if (chin != chout || group != chin) {
    CHECK_OR_FALSE(chin % group == 0);
    CHECK_OR_FALSE(chout % group == 0);
  }
  // Scratch buffer holding the column data of all groups of one batch.
  lite::Tensor workspace_tensor;
  std::vector<int64_t> wt_shape = {1, 1, 1, group * m * n};
  workspace_tensor.Resize(wt_shape);
  auto* workspace_ptr = workspace_tensor.mutable_data<Dtype2>();
  int group_size_in = win * hin * chin / group;
  int group_size_out = wout * hout * chout / group;
  int group_size_coldata = m * n;
  int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group);
  // A 1x1 kernel with unit stride/dilation and NO padding maps every GEMM
  // result straight onto one output pixel, so col2im can be skipped and the
  // GEMM may write directly into the output. With non-zero padding the output
  // is smaller than the GEMM result, so the fast path must not be taken
  // (it would write past the end of the output buffer).
  bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
                      (stride_w == 1) && (pad_w == 0) && (pad_h == 0) &&
                      (dila_w == 1) && (dila_h == 1);
  for (int i = 0; i < num; ++i) {
    const Dtype1* din_batch = din + i * chin * hin * win;
    Dtype2* dout_batch = dout + i * chout * hout * wout;
    Dtype2* col_data = workspace_ptr;
    if (flag_1x1s1p1) {
      col_data = dout_batch;
    }
    // Clear the column data of every group, not only the first one: each
    // group writes its own slice at offset g * group_size_coldata.
    memset(col_data, 0, sizeof(Dtype2) * group_size_coldata * group);
    for (int g = 0; g < group; ++g) {
      const Dtype1* din_group = din_batch + g * group_size_in;
      const Dtype1* weights_group = weights + g * group_size_weights;
      Dtype2* coldata_group = col_data + g * group_size_coldata;
      // NOTE(review): without bias the ReLU is fused into the GEMM, i.e. it is
      // applied before col2im on the non-1x1 path - confirm this is the
      // intended reference behaviour.
      basic_gemm<Dtype1, Dtype2>(m,
                                 n,
                                 k,
                                 weights_group,
                                 din_group,
                                 nullptr,
                                 coldata_group,
                                 (Dtype2)1,
                                 (Dtype2)0,
                                 true,
                                 false,
                                 false,
                                 (!flag_bias && flag_relu));
    }
    if (!flag_1x1s1p1) {
      lite::arm::math::col2im(col_data,
                              chout,
                              hout,
                              wout,
                              kernel_h,
                              kernel_w,
                              pad_h,
                              pad_w,
                              stride_h,
                              stride_w,
                              dila_h,
                              dila_w,
                              dout_batch);
    }
    if (flag_bias) {
      lite::arm::math::fill_bias_relu(
          dout_batch, bias, chout, wout * hout, flag_bias, flag_relu);
    }
  }
  return true;
}
// Reference implementation of conv2d_transpose driven by a ConvParam.
//
// Unpacks tensor shapes, filter size, strides, dilations and paddings from
// `param` and forwards them to deconv_basic, which writes into param.output.
// Dtype1 is the input/filter element type, Dtype2 the output/bias type.
//
// paddings is read as a 4-element vector; only paddings[0] (height) and
// paddings[2] (width) are used, i.e. the reference assumes symmetric padding.
template <typename Dtype1, typename Dtype2>
void conv2d_transpose_compute_ref(const operators::ConvParam& param) {
  const Dtype1* din = param.x->data<Dtype1>();
  Dtype2* dout = param.output->mutable_data<Dtype2>();

  // Shapes are indexed as [batch, channel, height, width].
  const auto& in_dims = param.x->dims();
  const auto& out_dims = param.output->dims();
  const auto& filter_dims = param.filter->dims();
  int num = in_dims[0];
  int chin = in_dims[1];
  int hin = in_dims[2];
  int win = in_dims[3];
  int chout = out_dims[1];
  int hout = out_dims[2];
  int wout = out_dims[3];

  const Dtype1* weights = param.filter->mutable_data<Dtype1>();
  const bool flag_bias = (param.bias != nullptr);
  Dtype2* bias = nullptr;
  if (flag_bias) {
    bias = param.bias->mutable_data<Dtype2>();
  }

  int group = param.groups;
  auto paddings = *param.paddings;
  auto dilations = *param.dilations;
  int kernel_h = filter_dims[2];
  int kernel_w = filter_dims[3];
  int stride_h = param.strides[0];
  int stride_w = param.strides[1];
  int dila_h = dilations[0];
  int dila_w = dilations[1];
  int pad_h = paddings[0];
  int pad_w = paddings[2];
  bool flag_relu = param.fuse_relu;

  // Forward this function's own element types instead of hard-coding float,
  // otherwise any non-float instantiation fails to compile.
  deconv_basic<Dtype1, Dtype2>(din,
                               dout,
                               num,
                               chout,
                               hout,
                               wout,
                               chin,
                               hin,
                               win,
                               weights,
                               bias,
                               group,
                               kernel_w,
                               kernel_h,
                               stride_w,
                               stride_h,
                               dila_w,
                               dila_h,
                               pad_w,
                               pad_h,
                               flag_bias,
                               flag_relu);
}
// The ARM float "conv2d_transpose" kernel must be obtainable from the global
// kernel registry.
TEST(conv2d_transpose_arm, retrive_op) {
  auto created_kernels =
      KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
          "conv2d_transpose");
  // At least one kernel was created ...
  ASSERT_FALSE(created_kernels.empty());
  // ... and the first entry is a usable (non-null) kernel.
  ASSERT_TRUE(created_kernels.front());
}
// A default-constructed kernel must report float precision and the ARM target.
TEST(conv2d_transpose_arm, init) {
  Conv2DTransposeCompute kernel;
  ASSERT_EQ(kernel.precision(), PRECISION(kFloat));
  ASSERT_EQ(kernel.target(), TARGET(kARM));
}
// Sweeps batch size, channels, spatial size, bias/ReLU flags, dilation,
// stride, padding, kernel size and group count, runs the ARM
// Conv2DTransposeCompute kernel and compares every output element against the
// reference implementation (conv2d_transpose_compute_ref) with 1e-3 tolerance.
TEST(conv2d_transpose_arm, compute) {
  DeviceInfo::Init();
  for (auto n : {1, 2}) {
    for (auto ic : {1, 3 /*, 128*/}) {
      for (auto oc : {1, 3 /*, 128*/}) {
        for (auto ih : {2, 8 /*, 56 , 112, 224, 512*/}) {
          for (auto iw : {2, 8 /*, 56, 112, 224, 512*/}) {
            for (auto flag_bias : {false, true}) {
              for (auto flag_relu : {false, true}) {
                for (auto dilation : {1, 2}) {
                  for (auto stride : {1, 2}) {
                    for (auto padding : {0, 1, 2}) {
                      for (auto ks : {2, 3, 5}) {
                        for (auto group : {1, 2}) {
                          // obtain shape
                          // Fall back to a single group when the channel
                          // counts are not divisible by the group count.
                          if (ic % group != 0 || oc % group != 0) {
                            group = 1;
                          }
                          std::vector<int64_t> input_shape = {n, ic, ih, iw};
                          std::vector<int64_t> filter_shape = {
                              oc / group, ic, ks, ks};
                          int oh = (ih - 1) * stride - 2 * padding +
                                   dilation * (ks - 1) + 1;
                          int ow = (iw - 1) * stride - 2 * padding +
                                   dilation * (ks - 1) + 1;
                          // The output size does not depend on `group`, so an
                          // empty output skips the remaining group values too.
                          if (oh < 1 || ow < 1) {
                            break;
                          }
                          std::vector<int64_t> output_shape = {n, oc, oh, ow};
                          std::vector<int64_t> bias_shape = {1, oc, 1, 1};

                          // define and resize tensor
                          Tensor input;
                          Tensor filter;
                          Tensor filter_copy;
                          Tensor bias;
                          Tensor output;
                          Tensor output_ref;
                          input.Resize(input_shape);
                          filter.Resize(filter_shape);
                          filter_copy.Resize(filter_shape);
                          output.Resize(output_shape);
                          output_ref.Resize(output_shape);
                          auto* input_data = input.mutable_data<float>();
                          auto* filter_data = filter.mutable_data<float>();
                          auto* filter_copy_data =
                              filter_copy.mutable_data<float>();
                          auto* output_data = output.mutable_data<float>();

                          // initialize tensor
                          for (int i = 0; i < input.dims().production(); i++) {
                            float sign = i % 3 == 0 ? -1.0f : 1.0f;
                            input_data[i] = sign * static_cast<float>(i % 128);
                          }
                          // The kernel and the reference each get their own
                          // filter tensor filled with identical values.
                          for (int i = 0; i < filter.dims().production(); i++) {
                            filter_data[i] =
                                i /
                                static_cast<float>(filter.dims().production());
                            filter_copy_data[i] =
                                i / static_cast<float>(
                                        filter_copy.dims().production());
                          }

                          // prepare kernel params and run
                          std::unique_ptr<KernelContext> ctx(new KernelContext);
                          ctx->As<ARMContext>();
                          Conv2DTransposeCompute conv2d_transpose;
                          conv2d_transpose.SetContext(std::move(ctx));
                          operators::ConvParam param;
                          param.x = &input;
                          param.filter = &filter;
                          param.output = &output;
                          param.bias = nullptr;
                          // The bias tensor is only sized and filled (once)
                          // when the bias path is under test.
                          if (flag_bias) {
                            bias.Resize(bias_shape);
                            auto* bias_data = bias.mutable_data<float>();
                            for (int i = 0; i < bias.dims().production(); i++) {
                              bias_data[i] = static_cast<float>(i);
                            }
                            param.bias = &bias;
                          }
                          param.fuse_relu = flag_relu;
                          std::vector<int> paddings = {
                              padding, padding, padding, padding};
                          param.strides = std::vector<int>({stride, stride});
                          std::vector<int> dilations = {dilation, dilation};
                          param.paddings =
                              std::make_shared<std::vector<int>>(paddings);
                          param.dilations =
                              std::make_shared<std::vector<int>>(dilations);
                          param.groups = group;
                          conv2d_transpose.SetParam(param);
                          conv2d_transpose.Launch();

                          // invoking ref implementation and compare results
                          param.filter = &filter_copy;
                          param.output = &output_ref;
                          conv2d_transpose_compute_ref<float, float>(param);
                          auto* output_ref_data =
                              output_ref.mutable_data<float>();
                          for (int i = 0; i < output.dims().production(); i++) {
                            EXPECT_NEAR(
                                output_data[i], output_ref_data[i], 1e-3);
                          }
                        }
                      }
                    }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Declares that this test uses the conv2d_transpose kernel (ARM target, float
// precision, NCHW layout, "def" alias). Presumably this keeps the kernel's
// registration linked into the test binary - TODO confirm against the macro.
USE_LITE_KERNEL(conv2d_transpose, kARM, kFloat, kNCHW, def);
...@@ -34,24 +34,73 @@ bool ConvTransposeOpLite::CheckShape() const { ...@@ -34,24 +34,73 @@ bool ConvTransposeOpLite::CheckShape() const {
CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U); CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U);
CHECK_OR_FALSE(in_dims[1] % param_.groups == 0); CHECK_OR_FALSE(in_dims[1] % param_.groups == 0);
CHECK_EQ_OR_FALSE(filter_dims.size(), 4UL);
return true; return true;
} }
// Output extent of a transposed convolution along one spatial axis:
//
//   (input_size - 1) * stride - (pad_left + pad_right)
//       + dilation * (filter_size - 1) + 1
inline int ConvTransposeOutputSize(int input_size,
                                   int filter_size,
                                   int dilation,
                                   int pad_left,
                                   int pad_right,
                                   int stride) {
  // Extent covered by the dilated kernel.
  const int dilated_kernel = dilation * (filter_size - 1) + 1;
  return (input_size - 1) * stride - pad_left - pad_right + dilated_kernel;
}
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
std::vector<int>* dilations,
const std::vector<int>& strides,
const std::string padding_algorithm,
const lite::DDim data_dims,
const lite::DDim& ksize) {
// when padding_desc is "VALID" or "SAME"
if (padding_algorithm == "SAME") {
for (size_t i = 0; i < strides.size(); ++i) {
int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
int pad_sum = std::max(
(out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
(int64_t)0);
int pad_0 = pad_sum / 2;
int pad_1 = pad_sum - pad_0;
// pad
*(paddings->begin() + i * 2) = pad_0;
*(paddings->begin() + i * 2 + 1) = pad_1;
// dilation
*(dilations->begin() + i) = 1;
}
} else if (padding_algorithm == "VALID") {
for (auto& it : *paddings) {
it = 0;
}
}
}
bool ConvTransposeOpLite::InferShape() const { bool ConvTransposeOpLite::InferShape() const {
const auto in_dims = param_.x->dims(); const auto in_dims = param_.x->dims();
const auto filter_dims = param_.filter->dims(); const auto filter_dims = param_.filter->dims();
UpdatePaddingAndDilation(param_.paddings.get(),
param_.dilations.get(),
param_.strides,
padding_algorithm_,
in_dims,
filter_dims);
auto paddings = *param_.paddings; auto paddings = *param_.paddings;
auto dilations = *param_.dilations; auto dilations = *param_.dilations;
std::vector<int64_t> output_shape; std::vector<int64_t> output_shape;
output_shape.push_back(in_dims[0]); output_shape.push_back(in_dims[0]);
output_shape.push_back(filter_dims[1] * param_.groups); output_shape.push_back(filter_dims[1] * param_.groups);
for (int i = 0; i < param_.strides.size(); i++) { for (size_t i = 0; i < param_.strides.size(); ++i) {
int kernel_extent = dilations[i] * (filter_dims[i + 2] - 1) + 1; output_shape.push_back(ConvTransposeOutputSize(in_dims[i + 2],
int output_len = (in_dims[i + 2] - 1) * param_.strides[i] + kernel_extent - filter_dims[i + 2],
(paddings[2 * i] + paddings[2 * i + 1]); dilations[i],
output_shape.push_back(output_len); paddings[i * 2],
paddings[i * 2 + 1],
param_.strides[i]));
} }
// Set output dims // Set output dims
...@@ -60,8 +109,8 @@ bool ConvTransposeOpLite::InferShape() const { ...@@ -60,8 +109,8 @@ bool ConvTransposeOpLite::InferShape() const {
} }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope *scope) { lite::Scope* scope) {
auto X = op_desc.Input("Input").front(); auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front(); auto Filter = op_desc.Input("Filter").front();
auto Out = op_desc.Output("Output").front(); auto Out = op_desc.Output("Output").front();
...@@ -74,6 +123,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, ...@@ -74,6 +123,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
auto dilations = op_desc.GetAttr<std::vector<int>>("dilations"); auto dilations = op_desc.GetAttr<std::vector<int>>("dilations");
if (op_desc.HasAttr("padding_algorithm")) {
padding_algorithm_ = op_desc.GetAttr<std::string>("padding_algorithm");
}
// 2-pad to 4-pad // 2-pad to 4-pad
if (paddings.size() == 2L) { if (paddings.size() == 2L) {
for (size_t i = 0; i < 2L; ++i) { for (size_t i = 0; i < 2L; ++i) {
...@@ -98,7 +150,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc, ...@@ -98,7 +150,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
auto bias_var = scope->FindVar(bias_arguments.front()); auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) { if (bias_var != nullptr) {
param_.bias = param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>())); const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
} }
} }
} }
......
...@@ -44,6 +44,7 @@ class ConvTransposeOpLite : public OpLite { ...@@ -44,6 +44,7 @@ class ConvTransposeOpLite : public OpLite {
private: private:
mutable ConvParam param_; mutable ConvParam param_;
std::string padding_algorithm_{""};
}; };
} // namespace operators } // namespace operators
......
...@@ -31,8 +31,10 @@ void col2im(const Dtype* data_col, ...@@ -31,8 +31,10 @@ void col2im(const Dtype* data_col,
const int width, const int width,
const int kernel_h, const int kernel_h,
const int kernel_w, const int kernel_w,
const int pad_h, const int pad_h0,
const int pad_w, const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h, const int stride_h,
const int stride_w, const int stride_w,
const int dilation_h, const int dilation_h,
...@@ -40,19 +42,22 @@ void col2im(const Dtype* data_col, ...@@ -40,19 +42,22 @@ void col2im(const Dtype* data_col,
Dtype* data_im) { Dtype* data_im) {
memset(data_im, 0, height * width * channels * sizeof(float)); memset(data_im, 0, height * width * channels * sizeof(float));
const int output_h = const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; (height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
stride_h +
1;
const int output_w = const int output_w =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; (width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int channel_size = height * width; const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) { for (int channel = channels; channel--; data_im += channel_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h + kernel_row * dilation_h; int input_row = -pad_h0 + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) { for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
data_col += output_w; data_col += output_w;
} else { } else {
int input_col = -pad_w + kernel_col * dilation_w; int input_col = -pad_w0 + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) { for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) { if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
data_im[input_row * width + input_col] += *data_col; data_im[input_row * width + input_col] += *data_col;
...@@ -104,6 +109,34 @@ void fill_bias_relu<float>(float* tensor, ...@@ -104,6 +109,34 @@ void fill_bias_relu<float>(float* tensor,
} }
} }
// Test-side helper that rewrites `paddings` (and, for "SAME", `dilations`) in
// place according to `padding_algorithm`:
//   "SAME"  : for every spatial axis i, computes the total padding from the
//             input extent data_dims[i + 2], the kernel extent ksize[i + 2]
//             and strides[i], splits it into a leading half (rounded down)
//             and the remainder, and resets the dilation of that axis to 1.
//   "VALID" : sets every padding entry to 0.
//   other   : leaves paddings and dilations untouched.
//
// Preconditions (not checked here): for "SAME", paddings must hold at least
// 2 * strides.size() entries, dilations at least strides.size() entries and
// ksize at least strides.size() + 2 entries.
// The string and dims are taken by const reference to avoid copying them on
// every call.
inline void UpdatePaddingAndDilation(std::vector<int>* paddings,
                                     std::vector<int>* dilations,
                                     const std::vector<int>& strides,
                                     const std::string& padding_algorithm,
                                     const DDim& data_dims,
                                     const std::vector<int>& ksize) {
  // when padding_desc is "VALID" or "SAME"
  if (padding_algorithm == "SAME") {
    for (size_t i = 0; i < strides.size(); ++i) {
      // ceil(input / stride)
      int out_size = (data_dims[i + 2] + strides[i] - 1) / strides[i];
      // Total padding for this axis, clamped at zero.
      int pad_sum = std::max(
          (out_size - 1) * strides[i] + ksize[i + 2] - data_dims[i + 2],
          static_cast<int64_t>(0));
      int pad_0 = pad_sum / 2;
      int pad_1 = pad_sum - pad_0;
      // pad
      (*paddings)[i * 2] = pad_0;
      (*paddings)[i * 2 + 1] = pad_1;
      // dilation
      (*dilations)[i] = 1;
    }
  } else if (padding_algorithm == "VALID") {
    for (auto& pad : *paddings) {
      pad = 0;
    }
  }
}
template <typename type, typename type2> template <typename type, typename type2>
static void basic_gemm(int m, static void basic_gemm(int m,
int n, int n,
...@@ -172,8 +205,10 @@ bool deconv_basic(const Dtype1* din, ...@@ -172,8 +205,10 @@ bool deconv_basic(const Dtype1* din,
int stride_h, int stride_h,
int dila_w, int dila_w,
int dila_h, int dila_h,
int pad_w, int pad_w0,
int pad_h, int pad_w1,
int pad_h0,
int pad_h1,
bool flag_bias, bool flag_bias,
bool flag_relu) { bool flag_relu) {
int m = chout * kernel_w * kernel_h / group; int m = chout * kernel_w * kernel_h / group;
...@@ -193,8 +228,9 @@ bool deconv_basic(const Dtype1* din, ...@@ -193,8 +228,9 @@ bool deconv_basic(const Dtype1* din,
int group_size_coldata = m * n; int group_size_coldata = m * n;
int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group); int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group);
bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
(stride_w == 1) && (pad_w == 1) && (pad_h == 1) && (stride_w == 1) && (pad_w0 == 0) && (pad_h0 == 0) &&
(dila_w == 1) && (dila_h == 1); (pad_w1 == 0) && (pad_h1 == 0) && (dila_w == 1) &&
(dila_h == 1);
for (int i = 0; i < num; ++i) { for (int i = 0; i < num; ++i) {
const Dtype1* din_batch = din + i * chin * hin * win; const Dtype1* din_batch = din + i * chin * hin * win;
...@@ -204,7 +240,7 @@ bool deconv_basic(const Dtype1* din, ...@@ -204,7 +240,7 @@ bool deconv_basic(const Dtype1* din,
if (flag_1x1s1p1) { if (flag_1x1s1p1) {
col_data = dout_batch; col_data = dout_batch;
} }
memset(col_data, 0, sizeof(Dtype2) * group_size_coldata); memset(col_data, 0, sizeof(Dtype2) * group_size_coldata * group);
for (int g = 0; g < group; ++g) { for (int g = 0; g < group; ++g) {
const Dtype1* din_group = din_batch + g * group_size_in; const Dtype1* din_group = din_batch + g * group_size_in;
const Dtype1* weights_group = weights + g * group_size_weights; const Dtype1* weights_group = weights + g * group_size_weights;
...@@ -230,8 +266,10 @@ bool deconv_basic(const Dtype1* din, ...@@ -230,8 +266,10 @@ bool deconv_basic(const Dtype1* din,
wout, wout,
kernel_h, kernel_h,
kernel_w, kernel_w,
pad_h, pad_h0,
pad_w, pad_h1,
pad_w0,
pad_w1,
stride_h, stride_h,
stride_w, stride_w,
dila_h, dila_h,
...@@ -253,9 +291,10 @@ class Conv2DTransposeComputeTester : public arena::TestCase { ...@@ -253,9 +291,10 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
std::string output_ = "out"; std::string output_ = "out";
std::string filter_ = "filter"; std::string filter_ = "filter";
std::string bias_ = "bias"; std::string bias_ = "bias";
std::string padding_algorithm_ = "";
std::vector<int> strides_{1, 1}; std::vector<int> strides_{1, 1};
std::vector<int> paddings_{0, 0}; std::vector<int> paddings_{0, 0, 0, 0};
int groups_{1}; int groups_{1};
std::vector<int> dilations_{1, 1}; std::vector<int> dilations_{1, 1};
bool flag_relu_{false}; bool flag_relu_{false};
...@@ -280,9 +319,13 @@ class Conv2DTransposeComputeTester : public arena::TestCase { ...@@ -280,9 +319,13 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
bool flag_relu, bool flag_relu,
int dilation, int dilation,
int stride, int stride,
int padding, int pad_h0,
int pad_h1,
int pad_w0,
int pad_w1,
int ks, int ks,
int groups) int groups,
std::string padding_algorithm)
: TestCase(place, alias) { : TestCase(place, alias) {
n_ = n; n_ = n;
ic_ = ic; ic_ = ic;
...@@ -291,20 +334,29 @@ class Conv2DTransposeComputeTester : public arena::TestCase { ...@@ -291,20 +334,29 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
iw_ = iw; iw_ = iw;
ks_ = ks; ks_ = ks;
flag_bias_ = flag_bias; flag_bias_ = flag_bias;
padding_algorithm_ = padding_algorithm;
strides_ = std::vector<int>({stride, stride}); strides_ = std::vector<int>({stride, stride});
paddings_ = std::vector<int>({padding, padding}); paddings_ = std::vector<int>({pad_h0, pad_h1, pad_w0, pad_w1});
groups_ = groups;
dilations_ = std::vector<int>({dilation, dilation}); dilations_ = std::vector<int>({dilation, dilation});
groups_ = groups;
flag_relu_ = flag_relu; flag_relu_ = flag_relu;
} }
void RunBaseline(Scope* scope) override { void RunBaseline(Scope* scope) override {
auto* out = scope->NewTensor(output_); auto* out = scope->NewTensor(output_);
CHECK(out); CHECK(out);
int oh = (ih_ - 1) * strides_[0] - 2 * paddings_[0] + auto* x = scope->FindTensor(x_);
auto input_dim = x->dims();
std::vector<int> ksize({1, 1, ks_, ks_});
UpdatePaddingAndDilation(&paddings_,
&dilations_,
strides_,
padding_algorithm_,
input_dim,
ksize);
int oh = (ih_ - 1) * strides_[0] - paddings_[0] - paddings_[1] +
dilations_[0] * (ks_ - 1) + 1; dilations_[0] * (ks_ - 1) + 1;
int ow = (iw_ - 1) * strides_[1] - 2 * paddings_[1] + int ow = (iw_ - 1) * strides_[1] - paddings_[2] - paddings_[3] +
dilations_[1] * (ks_ - 1) + 1; dilations_[1] * (ks_ - 1) + 1;
CHECK(oh > 0 || ow > 0); CHECK(oh > 0 || ow > 0);
...@@ -313,7 +365,6 @@ class Conv2DTransposeComputeTester : public arena::TestCase { ...@@ -313,7 +365,6 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
out->Resize(output_dims); out->Resize(output_dims);
auto* output_data = out->mutable_data<float>(); auto* output_data = out->mutable_data<float>();
auto* x = scope->FindTensor(x_);
const auto* x_data = x->data<float>(); const auto* x_data = x->data<float>();
auto* filter = scope->FindTensor(filter_); auto* filter = scope->FindTensor(filter_);
const auto* filter_data = filter->data<float>(); const auto* filter_data = filter->data<float>();
...@@ -341,8 +392,10 @@ class Conv2DTransposeComputeTester : public arena::TestCase { ...@@ -341,8 +392,10 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
strides_[0], strides_[0],
dilations_[1], dilations_[1],
dilations_[0], dilations_[0],
paddings_[1], paddings_[2],
paddings_[3],
paddings_[0], paddings_[0],
paddings_[1],
flag_bias_, flag_bias_,
flag_relu_); flag_relu_);
} }
...@@ -360,6 +413,7 @@ class Conv2DTransposeComputeTester : public arena::TestCase { ...@@ -360,6 +413,7 @@ class Conv2DTransposeComputeTester : public arena::TestCase {
op_desc->SetInput("Bias", {bias_}); op_desc->SetInput("Bias", {bias_});
} }
op_desc->SetAttr("fuse_relu", flag_relu_); op_desc->SetAttr("fuse_relu", flag_relu_);
op_desc->SetAttr("padding_algorithm", padding_algorithm_);
} }
void PrepareData() override { void PrepareData() override {
...@@ -402,49 +456,66 @@ TEST(conv2d_transpose, precision) { ...@@ -402,49 +456,66 @@ TEST(conv2d_transpose, precision) {
LOG(INFO) << "test conv2d_transpose op"; LOG(INFO) << "test conv2d_transpose op";
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
Place place(TARGET(kARM)); Place place(TARGET(kARM));
for (auto n : {1, 2}) { for (auto n : {2}) {
for (auto ic : {1, 4 /*, 128*/}) { for (auto ic : {1, 4 /*, 128*/}) {
for (auto oc : {1, 4 /*, 128*/}) { for (auto oc : {1, 4 /*, 128*/}) {
LOG(INFO) << "n:" << n << ",ic:" << ic << ",oc:" << oc; LOG(INFO) << "n:" << n << ",ic:" << ic << ",oc:" << oc;
for (auto ih : {8, 16 /*, 56 , 112, 224, 512*/}) { for (auto ih : {8, 8 /*, 56 , 112, 224, 512*/}) {
for (auto iw : {8, 16 /*, 56, 112, 224, 512*/}) { for (auto iw : {8, 16 /*, 56, 112, 224, 512*/}) {
for (auto flag_bias : {false, true}) { for (auto flag_bias : {false, true}) {
for (auto flag_relu : {false, true}) { for (auto flag_relu : {false, true}) {
for (auto dilation : {1, 2}) { for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) { for (auto stride : {1, 2}) {
for (auto padding : {0, 2}) { for (auto pad_h0 : {0, 1}) {
for (auto ks : {2, 5}) { for (auto pad_h1 : {0, 1}) {
for (auto group : {1, 2}) { for (auto pad_w0 : {0, 1}) {
// obtain shape for (auto pad_w1 : {0, 1}) {
// LOG(INFO) << "n:" << n << ",ic:" << ic << ",oc:" << for (auto ks : {1, 4}) {
// oc for (auto group : {1, 2}) {
// << ",ih:" << ih << ",iw:" << iw for (auto padding_algorithm :
// << ",flag_bias:" << flag_bias {"", "SAME", "VALID"}) {
// << ",flag_relu:" << flag_relu // obtain shape
// << ",dila:" << dilation // LOG(INFO) << "n:" << n << ",ic:" << ic <<
// << ",stride:" << stride // ",oc:" <<
// << ",padding:" << padding << ",ks:" << ks // oc
// << ",group:" << group; // << ",ih:" << ih << ",iw:" << iw
if (ic % group != 0 || oc % group != 0) { // << ",flag_bias:" << flag_bias
group = 1; // << ",flag_relu:" << flag_relu
// << ",dila:" << dilation
// << ",stride:" << stride
// << ",padding:" << padding <<
// ",ks:" << ks
// << ",group:" << group;
if (ic % group != 0 || oc % group != 0) {
group = 1;
}
std::unique_ptr<arena::TestCase> tester(
new Conv2DTransposeComputeTester(
place,
"def",
n,
ic,
oc,
ih,
iw,
flag_bias,
flag_relu,
dilation,
stride,
pad_h0,
pad_h1,
pad_w0,
pad_w1,
ks,
group,
padding_algorithm));
arena::Arena arena(
std::move(tester), place, 2e-5);
arena.TestPrecision();
}
}
}
} }
std::unique_ptr<arena::TestCase> tester(
new Conv2DTransposeComputeTester(place,
"def",
n,
ic,
oc,
ih,
iw,
flag_bias,
flag_relu,
dilation,
stride,
padding,
ks,
group));
arena::Arena arena(std::move(tester), place, 2e-5);
arena.TestPrecision();
} }
} }
} }
......
...@@ -111,11 +111,11 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -111,11 +111,11 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
param.output = new Tensor; param.output = new Tensor;
param.output->set_precision(PRECISION(kFloat)); param.output->set_precision(PRECISION(kFloat));
// paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f); paddle::lite::fill_tensor_rand(*param.filter, -1.f, 1.f);
paddle::lite::fill_tensor_const(*param.filter, 1.f); // paddle::lite::fill_tensor_const(*param.filter, 1.f);
if (flag_bias) { if (flag_bias) {
// paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f); paddle::lite::fill_tensor_rand(*param.bias, -1.f, 1.f);
paddle::lite::fill_tensor_const(*param.bias, 1.f); // paddle::lite::fill_tensor_const(*param.bias, 1.f);
} }
Tensor tmp_weights; Tensor tmp_weights;
tmp_weights.Resize(weight_dim); tmp_weights.Resize(weight_dim);
...@@ -130,21 +130,8 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -130,21 +130,8 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
new paddle::lite::KernelContext); new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>(); auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th); ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), th);
/// set param and context
for (auto& dim_in : input_dims) {
param.x->Resize(dim_in);
DDim out_tmp_dims = compute_out_dim(dim_in, param);
if (out_tmp_dims[2] < 1 || out_tmp_dims[3] < 1) {
continue;
}
param.output->Resize(out_tmp_dims);
break;
}
conv_t.SetParam(param); conv_t.SetParam(param);
conv_t.SetContext(std::move(ctx1)); conv_t.SetContext(std::move(ctx1));
/// prepare for run
conv_t.PrepareForRun();
for (auto& dim_in : input_dims) { for (auto& dim_in : input_dims) {
CHECK_EQ(weight_dim[0], dim_in[1]) CHECK_EQ(weight_dim[0], dim_in[1])
<< "input channel must equal to weights channel"; << "input channel must equal to weights channel";
...@@ -154,9 +141,11 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -154,9 +141,11 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
} }
param.x->Resize(dim_in); param.x->Resize(dim_in);
param.output->Resize(dim_out); param.output->Resize(dim_out);
param.filter->CopyDataFrom(tmp_weights);
// paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f); // prepare for run
paddle::lite::fill_tensor_const(*param.x, 1.f); conv_t.PrepareForRun();
paddle::lite::fill_tensor_rand(*param.x, -1.f, 1.f);
// paddle::lite::fill_tensor_const(*param.x, 1.f);
auto din = param.x->data<float>(); auto din = param.x->data<float>();
Tensor tout_basic; Tensor tout_basic;
...@@ -185,7 +174,9 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -185,7 +174,9 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
dilas[1], dilas[1],
dilas[0], dilas[0],
pads[2], pads[2],
pads[3],
pads[0], pads[0],
pads[1],
flag_bias, flag_bias,
flag_relu); flag_relu);
} }
...@@ -230,7 +221,8 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -230,7 +221,8 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
LOG(FATAL) << "test fp32 conv: input: " << dim_in LOG(FATAL) << "test fp32 conv: input: " << dim_in
<< ", output: " << dim_out << ", output: " << dim_out
<< ", weight dim: " << weight_dim << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", pad: " << pads[0] << ", " << pads[1] << ", "
<< pads[2] << ", " << pads[3]
<< ", stride: " << strides[0] << ", " << strides[1] << ", stride: " << strides[0] << ", " << strides[1]
<< ", dila_: " << dilas[0] << ", " << dilas[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
...@@ -242,9 +234,9 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims, ...@@ -242,9 +234,9 @@ void test_conv_transpose_fp32(const std::vector<DDim>& input_dims,
} }
LOG(INFO) << "test fp32 conv: input: " << dim_in LOG(INFO) << "test fp32 conv: input: " << dim_in
<< ", output: " << dim_out << ", weight dim: " << weight_dim << ", output: " << dim_out << ", weight dim: " << weight_dim
<< ", pad: " << pads[0] << ", " << pads[1] << ", pad: " << pads[0] << ", " << pads[1] << ", " << pads[2]
<< ", stride: " << strides[0] << ", " << strides[1] << ", " << pads[3] << ", stride: " << strides[0] << ", "
<< ", dila_: " << dilas[0] << ", " << dilas[1] << strides[1] << ", dila_: " << dilas[0] << ", " << dilas[1]
<< ", bias: " << (flag_bias ? "true" : "false") << ", bias: " << (flag_bias ? "true" : "false")
<< ", relu: " << (flag_relu ? "true" : "false") << ", relu: " << (flag_relu ? "true" : "false")
<< ", threads: " << th << ", power_mode: " << cls << ", threads: " << th << ", power_mode: " << cls
...@@ -280,30 +272,37 @@ TEST(TestConvRand, test_conv_transpose_rand) { ...@@ -280,30 +272,37 @@ TEST(TestConvRand, test_conv_transpose_rand) {
for (auto& kw : {1, 2, 3}) { for (auto& kw : {1, 2, 3}) {
for (auto& kh : {1, 2, 3}) { for (auto& kh : {1, 2, 3}) {
for (auto& stride : {1, 2}) { for (auto& stride : {1, 2}) {
for (auto& pad : {0, 1, 2}) { for (auto& pad_h0 : {0, 1, 2}) {
for (auto& dila : {1, 2}) { for (auto& pad_h1 : {0, 1, 2}) {
for (auto& flag_bias : {false, true}) { for (auto& pad_w0 : {0, 1, 2}) {
for (auto& flag_relu : {false, true}) { for (auto& pad_w1 : {0, 1, 2}) {
if (cin % g != 0 || cout % g != 0) { for (auto& dila : {1, 2}) {
continue; for (auto& flag_bias : {false, true}) {
} for (auto& flag_relu : {false, true}) {
std::vector<DDim> dims; if (cin % g != 0 || cout % g != 0) {
DDim weights_dim({cin, cout / g, kh, kw}); continue;
for (auto& batch : {1, 2}) { }
for (auto& h : {1, 3, 19, 32, 28}) { std::vector<DDim> dims;
dims.push_back(DDim({batch, cin, h, h})); DDim weights_dim({cin, cout / g, kh, kw});
for (auto& batch : {1, 2}) {
for (auto& h : {1, 3, 19, 32, 28}) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
test_conv_transpose_fp32(
dims,
weights_dim,
g,
{stride, stride},
{pad_h0, pad_h1, pad_w0, pad_w1},
{dila, dila},
flag_bias,
flag_relu,
{1, 4},
{FLAGS_power_mode});
}
} }
} }
test_conv_transpose_fp32(dims,
weights_dim,
g,
{stride, stride},
{pad, pad, pad, pad},
{dila, dila},
flag_bias,
flag_relu,
{1, 2, 4},
{FLAGS_power_mode});
} }
} }
} }
......
...@@ -330,8 +330,10 @@ static void col2im(const Dtype* data_col, ...@@ -330,8 +330,10 @@ static void col2im(const Dtype* data_col,
const int width, const int width,
const int kernel_h, const int kernel_h,
const int kernel_w, const int kernel_w,
const int pad_h, const int pad_h0,
const int pad_w, const int pad_h1,
const int pad_w0,
const int pad_w1,
const int stride_h, const int stride_h,
const int stride_w, const int stride_w,
const int dilation_h, const int dilation_h,
...@@ -339,21 +341,24 @@ static void col2im(const Dtype* data_col, ...@@ -339,21 +341,24 @@ static void col2im(const Dtype* data_col,
Dtype* data_im) { Dtype* data_im) {
memset(data_im, 0, height * width * channels * sizeof(Dtype)); memset(data_im, 0, height * width * channels * sizeof(Dtype));
const int output_h = const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; (height + pad_h0 + pad_h1 - (dilation_h * (kernel_h - 1) + 1)) /
stride_h +
1;
const int output_w = const int output_w =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; (width + pad_w0 + pad_w1 - (dilation_w * (kernel_w - 1) + 1)) / stride_w +
1;
const int channel_size = height * width; const int channel_size = height * width;
for (int channel = channels; channel--; data_im += channel_size) { for (int channel = channels; channel--; data_im += channel_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) { for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) { for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h + kernel_row * dilation_h; int input_row = -pad_h0 + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) { for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) { if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
data_col += output_w; data_col += output_w;
} else { } else {
int input_col = -pad_w + kernel_col * dilation_w; int input_col = -pad_w0 + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) { for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) { if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
...@@ -391,8 +396,10 @@ void deconv_basic(const Dtype1* din, ...@@ -391,8 +396,10 @@ void deconv_basic(const Dtype1* din,
int stride_h, int stride_h,
int dila_w, int dila_w,
int dila_h, int dila_h,
int pad_w, int pad_w0,
int pad_h, int pad_w1,
int pad_h0,
int pad_h1,
bool flag_bias, bool flag_bias,
bool flag_relu) { bool flag_relu) {
int m = chout * kernel_w * kernel_h / group; int m = chout * kernel_w * kernel_h / group;
...@@ -404,8 +411,9 @@ void deconv_basic(const Dtype1* din, ...@@ -404,8 +411,9 @@ void deconv_basic(const Dtype1* din,
int group_size_coldata = m * n; int group_size_coldata = m * n;
int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group); int group_size_weights = chin * chout * kernel_w * kernel_h / (group * group);
bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) && bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
(stride_w == 1) && (pad_w == 1) && (pad_h == 1) && (stride_w == 1) && (pad_w0 == 0) && (pad_h0 == 0) &&
(dila_w == 1) && (dila_h == 1); (pad_w1 == 0) && (pad_h1 == 0) && (dila_w == 1) &&
(dila_h == 1);
Dtype2* workspace_ptr = Dtype2* workspace_ptr =
static_cast<Dtype2*>(malloc(sizeof(float) * m * n * group)); static_cast<Dtype2*>(malloc(sizeof(float) * m * n * group));
...@@ -418,7 +426,7 @@ void deconv_basic(const Dtype1* din, ...@@ -418,7 +426,7 @@ void deconv_basic(const Dtype1* din,
if (flag_1x1s1p1) { if (flag_1x1s1p1) {
col_data = dout_batch; col_data = dout_batch;
} }
memset(col_data, 0, sizeof(Dtype2) * group_size_coldata); memset(col_data, 0, sizeof(Dtype2) * group_size_coldata * group);
for (int g = 0; g < group; ++g) { for (int g = 0; g < group; ++g) {
const Dtype1* din_group = din_batch + g * group_size_in; const Dtype1* din_group = din_batch + g * group_size_in;
const Dtype1* weights_group = weights + g * group_size_weights; const Dtype1* weights_group = weights + g * group_size_weights;
...@@ -448,8 +456,10 @@ void deconv_basic(const Dtype1* din, ...@@ -448,8 +456,10 @@ void deconv_basic(const Dtype1* din,
wout, wout,
kernel_h, kernel_h,
kernel_w, kernel_w,
pad_h, pad_h0,
pad_w, pad_h1,
pad_w0,
pad_w1,
stride_h, stride_h,
stride_w, stride_w,
dila_h, dila_h,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册