未验证 提交 12a17969 编写于 作者: H HappyAngel 提交者: GitHub

【arm】fix pooling unequal-padding problem (#2956)

上级 294375f9
......@@ -92,7 +92,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
const operators::ActivationParam act_param,
ARMContext* ctx) {
if (pad == 0) {
if (w_in > 7) {
if (w_in > 8) {
conv_depthwise_3x3s2p0_bias(dout,
din,
weights,
......@@ -476,7 +476,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
\
"st1 {v16.4s}, [%[outptr0]], #16 \n" \
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
"fmul v12.4s, v16.4s, v22.4s \n" \
"fmul v12.4s, v17.4s, v22.4s \n" \
\
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
......@@ -552,6 +552,7 @@ void conv_depthwise_3x3s2_fp32(const float* din,
"ld1 {v20.4s}, [%[inptr3]] \n" \
"ld1 {v21.4s}, [%[inptr4]] \n" \
\
"fadd v17.4s, v17.4s, v14.4s \n" \
"bif v16.16b, v12.16b, v11.16b \n" /* choose*/ \
"ext v10.16b, v0.16b, v15.16b, #4 \n" \
"fcmge v11.4s, v17.4s, %[vzero].4s \n" /* vcgeq_u32 */ \
......
......@@ -113,9 +113,9 @@ namespace math {
"fcmge v7.4s, v22.4s, v0.4s \n" /* vcgeq_u32 */ \
"fmul v8.4s, v22.4s, %[vscale].4s \n" /* mul */ \
"bif v19.16b, v2.16b, v1.16b \n" /* choose*/ \
"bif v19.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v19.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v19.16b, v8.16b, v7.16b \n" /* choose*/
"bif v20.16b, v4.16b, v3.16b \n" /* choose*/ \
"bif v21.16b, v6.16b, v5.16b \n" /* choose*/ \
"bif v22.16b, v8.16b, v7.16b \n" /* choose*/
#define STORE /* save result */ \
"str q19, [%[outc0]], #16\n" \
"str q20, [%[outc1]], #16\n" \
......
......@@ -67,7 +67,6 @@ void pooling_basic(const float* din,
}
} else if (pooling_type == "avg") {
// Pooling_average_include_padding
// Pooling_average_exclude_padding
for (int n = 0; n < num; ++n) {
float* dout_batch = dout + n * chout * size_channel_out;
const float* din_batch = din + n * chin * size_channel_in;
......@@ -906,7 +905,9 @@ void pooling1x1s2p0_max(const float* din,
int wout,
int chin,
int hin,
int win) {
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1021,7 +1022,9 @@ void pooling2x2s2_max(const float* din,
int wout,
int chin,
int hin,
int win) {
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1104,7 +1107,9 @@ void pooling2x2s2_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive) {
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1117,6 +1122,9 @@ void pooling2x2s2_avg(const float* din,
int w_unroll_size = wout / 4;
int w_unroll_remian = wout - w_unroll_size * 4;
float32x4_t vcoef = vdupq_n_f32(0.25f); // divided by 4
auto zero_ptr =
static_cast<float*>(TargetMalloc(TARGET(kARM), win * sizeof(float)));
memset(zero_ptr, 0, win * sizeof(float));
for (int n = 0; n < num; ++n) {
float* data_out_batch = data_out + n * chout * size_channel_out;
......@@ -1132,7 +1140,7 @@ void pooling2x2s2_avg(const float* din,
auto dr0 = r0;
auto dr1 = r1;
if (h * S + K - P > hin) {
dr1 = r0;
dr1 = zero_ptr;
}
int cnt_num = w_unroll_size;
if (w_unroll_size > 0) {
......@@ -1178,6 +1186,7 @@ void pooling2x2s2_avg(const float* din,
}
}
}
TargetFree(TARGET(kARM), zero_ptr);
}
void pooling3x3s1p1_max(const float* din,
......@@ -1188,7 +1197,9 @@ void pooling3x3s1p1_max(const float* din,
int wout,
int chin,
int hin,
int win) {
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1331,7 +1342,9 @@ void pooling3x3s1p1_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive) {
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1389,7 +1402,13 @@ void pooling3x3s1p1_avg(const float* din,
if (exclusive) {
coef_h = 1.f;
} else {
coef_h = 0.5f;
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom == 1) {
coef_h = 0.5f;
} else {
coef_h = 1.f;
}
}
break;
case 1:
......@@ -1401,7 +1420,11 @@ void pooling3x3s1p1_avg(const float* din,
coef_h = 0.5f;
}
} else {
coef_h = 1.f / 3;
if (pad_bottom >= 1) {
coef_h = 1.f / 3;
} else {
coef_h = 0.5f;
}
}
default:
break;
......@@ -1477,8 +1500,12 @@ void pooling3x3s1p1_avg(const float* din,
int st = wstart > 0 ? wstart : 0;
if (wstart + K > win) {
wend = win;
if (!exclusive && wstart + K - win == 2) {
coef = coef_h / 2;
if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
}
}
if (exclusive) {
......@@ -1509,7 +1536,9 @@ void pooling3x3s1p0_max(const float* din,
int wout,
int chin,
int hin,
int win) {
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1646,7 +1675,9 @@ void pooling3x3s1p0_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive) {
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1692,7 +1723,13 @@ void pooling3x3s1p0_avg(const float* din,
if (exclusive) {
coef_h = 1.f;
} else {
coef_h = 0.5f;
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom = 1) {
coef_h = 0.5f;
} else {
coef_h = 1.f;
}
}
break;
case 1:
......@@ -1704,7 +1741,11 @@ void pooling3x3s1p0_avg(const float* din,
coef_h = 0.5f;
}
} else {
coef_h = 1.f / 3;
if (pad_bottom >= 1) {
coef_h = 1.f / 3;
} else {
coef_h = 0.5f;
}
}
default:
break;
......@@ -1776,8 +1817,12 @@ void pooling3x3s1p0_avg(const float* din,
int st = wstart > 0 ? wstart : 0;
if (wstart + K > win) {
wend = win;
if (!exclusive && wstart + K - win == 2) {
coef = coef_h / 2;
if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
}
}
if (exclusive) {
......@@ -1811,7 +1856,9 @@ void pooling3x3s2p1_max(const float* din,
int wout,
int chin,
int hin,
int win) {
int win,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -1955,7 +2002,9 @@ void pooling3x3s2p1_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive) {
bool exclusive,
int pad_bottom,
int pad_right) {
int size_channel_out = wout * hout;
int size_channel_in = win * hin;
auto data_out = static_cast<float*>(dout);
......@@ -2015,7 +2064,13 @@ void pooling3x3s2p1_avg(const float* din,
if (exclusive) {
coef_h = 1.f;
} else {
coef_h = 0.5f;
if (pad_bottom > 1) {
coef_h = 1.f / 3;
} else if (pad_bottom == 1) {
coef_h = 0.5f;
} else {
coef_h = 1.f;
}
}
break;
case 1:
......@@ -2027,7 +2082,11 @@ void pooling3x3s2p1_avg(const float* din,
coef_h = 0.5f;
}
} else {
coef_h = 1.f / 3;
if (pad_bottom == 0) {
coef_h = 1.f / 2;
} else {
coef_h = 1.f / 3;
}
}
default:
break;
......@@ -2102,8 +2161,12 @@ void pooling3x3s2p1_avg(const float* din,
float coef = coef_h / 3.f;
if (wstart + K > win) {
wend = win;
if (!exclusive && wstart + K - win == 2) {
coef = coef_h / 2;
if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
}
}
int st = wstart > 0 ? wstart : 0;
......@@ -2135,7 +2198,9 @@ void pooling3x3s2p0_max(const float* din,
int wout,
int chin,
int hin,
int win) {
int win,
int pad_bottom,
int pad_right) {
const int K = 3;
const int P = 0;
const int S = 2;
......@@ -2261,7 +2326,9 @@ void pooling3x3s2p0_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive) {
bool exclusive,
int pad_bottom,
int pad_right) {
const int K = 3;
const int P = 0;
const int S = 2;
......@@ -2303,11 +2370,33 @@ void pooling3x3s2p0_avg(const float* din,
case 2:
dr1 = zero_ptr;
dr2 = zero_ptr;
coef_h = 1.f;
if (exclusive) {
coef_h = 1.f;
} else {
if (pad_bottom >= 2) {
coef_h = 1.f / 3;
} else if (pad_bottom == 1) {
coef_h = 0.5f;
} else {
coef_h = 1.0f;
}
}
break;
case 1:
dr2 = zero_ptr;
coef_h = 0.5f;
if (exclusive) {
if (fabsf(coef_h - 0.5f) < 1e-6f) {
coef_h = 1.f;
} else {
coef_h = 0.5f;
}
} else {
if (pad_bottom >= 1) {
coef_h = 1.0f / 3;
} else {
coef_h = 0.5f;
}
}
break;
default:
break;
......@@ -2366,22 +2455,34 @@ void pooling3x3s2p0_avg(const float* din,
dr2 -= 8;
}
// deal with right pad
int rem = win - (w_unroll_size * 4) * S;
int wstart = 0;
int wstart = w_unroll_size * 4 * S - P;
for (int j = 0; j < w_unroll_remian; ++j) {
int wend = std::min(wstart + K, rem);
float coef = coef_h / (wend - wstart);
int wend = wstart + K; // std::min(wstart + K, win);
float coef = coef_h / 3.f;
if (wstart + K > win) {
wend = win;
if (!exclusive) {
if (wstart + K - pad_right - win == 1) {
coef = coef_h / 2;
} else if (wstart + K - pad_right - win == 2) {
coef = coef_h;
}
}
}
int st = wstart > 0 ? wstart : 0;
if (exclusive) {
coef = coef_h / (wend - st);
}
float tmp = 0.f;
for (int i = wstart; i < wend; i++) {
tmp += dr0[i];
tmp += dr1[i];
tmp += dr2[i];
for (int i = 0; i < wend - st; i++) {
tmp += dr0[i] + dr1[i] + dr2[i];
}
tmp *= coef;
*(dr_out++) = tmp;
*(dr_out++) = tmp * coef;
dr0 += S - (st - wstart);
dr1 += S - (st - wstart);
dr2 += S - (st - wstart);
wstart += S;
}
r0 = r2;
r1 = r0 + win;
r2 = r1 + win;
......
......@@ -72,7 +72,9 @@ void pooling1x1s2p0_max(const float* din,
int wout,
int chin,
int hin,
int win);
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2_max(const float* din,
float* dout,
......@@ -82,7 +84,9 @@ void pooling2x2s2_max(const float* din,
int wout,
int chin,
int hin,
int win);
int win,
int pad_bottom,
int pad_right);
void pooling2x2s2_avg(const float* din,
float* dout,
......@@ -93,7 +97,9 @@ void pooling2x2s2_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive);
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_max(const float* din,
float* dout,
......@@ -103,7 +109,9 @@ void pooling3x3s1p1_max(const float* din,
int wout,
int chin,
int hin,
int win);
int win,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_avg(const float* din,
float* dout,
......@@ -114,7 +122,9 @@ void pooling3x3s1p1_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive);
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s2p1_max(const float* din,
float* dout,
......@@ -124,7 +134,9 @@ void pooling3x3s2p1_max(const float* din,
int wout,
int chin,
int hin,
int win);
int win,
int pad_bottom,
int pad_right);
void pooling3x3s1p0_max(const float* din,
float* dout,
......@@ -134,7 +146,9 @@ void pooling3x3s1p0_max(const float* din,
int wout,
int chin,
int hin,
int win);
int win,
int pad_bottom,
int pad_right);
void pooling3x3s1p0_avg(const float* din,
float* dout,
......@@ -145,7 +159,9 @@ void pooling3x3s1p0_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive);
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s2p1_avg(const float* din,
float* dout,
......@@ -156,7 +172,9 @@ void pooling3x3s2p1_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive);
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s2p0_max(const float* din,
float* dout,
......@@ -166,7 +184,9 @@ void pooling3x3s2p0_max(const float* din,
int wout,
int chin,
int hin,
int win);
int win,
int pad_bottom,
int pad_right);
void pooling3x3s2p0_avg(const float* din,
float* dout,
......@@ -177,7 +197,9 @@ void pooling3x3s2p0_avg(const float* din,
int chin,
int hin,
int win,
bool exclusive);
bool exclusive,
int pad_bottom,
int pad_right);
} // namespace math
} // namespace arm
......
......@@ -24,17 +24,28 @@ namespace mir {
void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> act_types{"relu"};
bool has_int8 = false;
bool has_arm_float = false;
bool has_cuda = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kCUDA)) {
act_types.push_back("leaky_relu");
break;
if (place.precision == PRECISION(kInt8)) {
has_int8 = true;
}
if (place.target == TARGET(kARM) && place.precision == PRECISION(kFloat)) {
act_types.push_back("relu6");
act_types.push_back("leaky_relu");
break;
has_arm_float = true;
}
if (place.target == TARGET(kCUDA)) {
has_cuda = true;
}
}
if (!has_int8 && has_arm_float) {
act_types.push_back("relu6");
act_types.push_back("leaky_relu");
}
if (!has_int8 && has_cuda) {
act_types.push_back("leaky_relu");
}
for (auto conv_type : {"conv2d", "depthwise_conv2d", "conv2d_transpose"}) {
for (auto act_type : act_types) {
for (auto has_bias : {true, false}) {
......
......@@ -53,14 +53,6 @@ void ConvActivationFuser::BuildPattern() {
void ConvActivationFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
// not fuse quantized conv2d + relu6 for now
auto conv2d_op_desc = matched.at("conv2d")->stmt()->op_info();
bool is_conv2d_quantized = conv2d_op_desc->HasAttr("enable_int8") &&
conv2d_op_desc->GetAttr<bool>("enable_int8");
if (act_type_ == "relu6" && is_conv2d_quantized) {
return;
}
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op();
......
......@@ -47,13 +47,16 @@ void PoolCompute::Run() {
bool use_quantizer = param.use_quantizer;
std::string& data_format = param.data_format;
bool pads_equal = (paddings[0] == paddings[1]) &&
(paddings[2] == paddings[3]) &&
(paddings[0] == paddings[2]);
bool pads_less =
(paddings[0] == paddings[2]) && (paddings[1] < 2) && (paddings[3] < 2);
bool pads_equal = (paddings[0] == paddings[2]) &&
(paddings[0] == paddings[1]) &&
(paddings[2] == paddings[3]);
bool kps_equal =
(ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_equal;
(ksize[0] == ksize[1]) && (strides[0] == strides[1]) && pads_less;
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && pads_equal;
(ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling;
if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) {
......@@ -96,7 +99,9 @@ void PoolCompute::Run() {
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
in_dims[3],
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
......@@ -110,7 +115,9 @@ void PoolCompute::Run() {
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din,
......@@ -122,7 +129,9 @@ void PoolCompute::Run() {
in_dims[1],
in_dims[2],
in_dims[3],
exclusive);
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
......@@ -136,7 +145,9 @@ void PoolCompute::Run() {
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p1_avg(din,
......@@ -148,7 +159,9 @@ void PoolCompute::Run() {
in_dims[1],
in_dims[2],
in_dims[3],
exclusive);
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
......@@ -162,7 +175,9 @@ void PoolCompute::Run() {
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s1p0_avg(din,
......@@ -174,7 +189,9 @@ void PoolCompute::Run() {
in_dims[1],
in_dims[2],
in_dims[3],
exclusive);
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
......@@ -188,7 +205,9 @@ void PoolCompute::Run() {
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p0_avg(din,
......@@ -200,7 +219,9 @@ void PoolCompute::Run() {
in_dims[1],
in_dims[2],
in_dims[3],
exclusive);
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
......@@ -214,7 +235,9 @@ void PoolCompute::Run() {
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3]);
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling3x3s2p1_avg(din,
......@@ -226,11 +249,14 @@ void PoolCompute::Run() {
in_dims[1],
in_dims[2],
in_dims[3],
exclusive);
exclusive,
paddings[1],
paddings[3]);
return;
}
}
}
lite::arm::math::pooling_basic(din,
dout,
out_dims[0],
......
......@@ -232,7 +232,7 @@ TEST(pool_arm, compute) {
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
#if 0
// speedup for ci
for (auto pooling_type : {"max", "avg"}) {
for (auto ceil_mode : {true, false}) {
......@@ -337,6 +337,7 @@ TEST(pool_arm, compute) {
}
}
}
#endif
}
TEST(pool_arm, retrive_op) {
......
......@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册