未验证 提交 db2ab554 编写于 作者: H HappyAngel 提交者: GitHub

fix conv_3x3s1_dw v7-compute nan problem (#4309)

* fix conv_3x3s1_dw v7-compute nan. test=develop

* fix compute. test=develop

* set sgemm basic_test to false. test=develop
上级 0108c64e
......@@ -620,8 +620,10 @@ void conv_depthwise_3x3_fp32(const void* din,
int pad = pad_w;
bool flag_bias = param.bias != nullptr;
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
bool ch_four = ch_in <= 4 * w_in;
if (stride == 1) {
if (pads_less && (pad_h == pad_w) && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && (pad_h == pad_w) &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s1_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -638,7 +640,6 @@ void conv_depthwise_3x3_fp32(const void* din,
act_param,
ctx);
} else {
#ifdef __aarch64__
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......@@ -653,30 +654,10 @@ void conv_depthwise_3x3_fp32(const void* din,
param,
act_param,
ctx);
#else
#ifdef LITE_WITH_ARM_CLANG
LOG(FATAL) << "fp32 depthwise conv3x3s1px doesnot support in v7-clang, "
"this can run in basic";
#else
conv_3x3s1_depthwise_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
ch_out,
h_out,
w_out,
ch_in,
h_in,
w_in,
reinterpret_cast<const float*>(weights),
bias,
param,
act_param,
ctx);
#endif
#endif
}
} else if (stride == 2) {
if (pads_less && pad_h == pad_w && (pad < 2)) { // support pad = [0, 1]
if (ch_four && pads_less && pad_h == pad_w &&
(pad < 2)) { // support pad = [0, 1]
conv_depthwise_3x3s2_fp32(reinterpret_cast<const float*>(din),
reinterpret_cast<float*>(dout),
num,
......
......@@ -59,12 +59,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
bool flag_dw_3x3 = (kw == 3) && (kh == 3) && (stride == 1 || stride == 2);
bool flag_dw_5x5 = (kw == 5) && (kh == 5) && (stride == 1 || stride == 2);
#ifdef __aarch64__
#else
bool flag =
(stride == 1 && (paddings[0] > 1 || paddings[2] > 1)) ? false : true;
flag_dw_3x3 = flag_dw_3x3 && flag;
#endif
bool flag_dw = flag_dw_3x3 || flag_dw_5x5;
/// select conv impl
......
......@@ -28,11 +28,15 @@ void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
auto w_dims = param.filter->dims();
auto kw = w_dims[3];
auto channel = w_dims[0];
auto hin = param.x->dims()[2];
auto win = param.x->dims()[3];
auto paddings = *param.paddings;
bool ch_four = channel <= 4 * win;
// select dw conv kernel
if (kw == 3) {
bool pads_less = ((paddings[1] < 2) && (paddings[3] < 2));
if (pads_less && paddings[0] == paddings[2] &&
if (ch_four && pads_less && paddings[0] == paddings[2] &&
(paddings[0] == 0 || paddings[0] == 1)) {
flag_trans_weights_ = false;
} else {
......@@ -398,6 +402,14 @@ void DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
w_scale_.data());
}
#ifdef LITE_WITH_PROFILE
/// Copies this kernel's `kernel_func_name_` into the supplied OpCharacter
/// record so that it is available through `ch->kernel_func_name`.
///
/// @param ch  Record to fill in; it is dereferenced unconditionally.
template <>
void DepthwiseConv<PRECISION(kFloat), PRECISION(kFloat)>::
    SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) {
  auto& op_character = *ch;
  op_character.kernel_func_name = this->kernel_func_name_;
}
#endif
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -39,6 +39,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
#ifdef LITE_WITH_ARM
// sgemm_test will not be run unless it is built
// for the ARM backend.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册