diff --git a/lite/backends/arm/math/softmax.cc b/lite/backends/arm/math/softmax.cc index 65d41b049123680f26674cc05d3c02172a260b31..b7f82e9f376e8b62195d884e8de19a142d76b316 100644 --- a/lite/backends/arm/math/softmax.cc +++ b/lite/backends/arm/math/softmax.cc @@ -531,7 +531,7 @@ void softmax_inner1_large_axis(const float* din, } float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax)); float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1)); - for (j = 4 * j; j < axis_size; ++j) { + for (j = 4 * nn; j < axis_size; ++j) { max_data = std::max(max_data, din_max_ptr[0]); din_max_ptr++; } @@ -557,7 +557,7 @@ void softmax_inner1_large_axis(const float* din, float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum)); float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1); - for (j = 4 * j; j < axis_size; ++j) { + for (j = 4 * nn; j < axis_size; ++j) { dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data); sum_data += dout_sum_ptr[0]; din_sum_ptr++; diff --git a/lite/kernels/arm/conv_compute.cc b/lite/kernels/arm/conv_compute.cc index ef174814ced73d4b2ec20580e06c63d39693ce57..140a8365fe4fb4b6fc0be55d7e3898ede9e7f73e 100644 --- a/lite/kernels/arm/conv_compute.cc +++ b/lite/kernels/arm/conv_compute.cc @@ -121,10 +121,14 @@ void ConvCompute::PrepareForRun() { no_dilation && flag_dw) { impl_ = new DepthwiseConv; // VLOG(3) << "Run DepthwiseConv Int8"; + } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation && + pads_equal) { + impl_ = new DirectConv; + // VLOG(3) << "Run DirectConv Int8"; } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation && pads_equal) { impl_ = new WinogradConv; - // VLOG(3) << "Run DirectConv Int8"; + // VLOG(3) << "Run WinogradConv Int8"; } else { impl_ = new GemmLikeConv; // VLOG(3) << "Run GemmLikeConvInt8"; @@ -168,10 +172,14 @@ void ConvCompute::PrepareForRun() { no_dilation && flag_dw) { impl_ = new DepthwiseConv; // VLOG(3) << "Run DepthwiseConv Int8"; + }else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation && + pads_equal) { + impl_ = new DirectConv; + // VLOG(3) << "Run DirectConv Int8"; } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation && pads_equal) { impl_ = new WinogradConv; - // VLOG(3) << "Run DirectConv Int8"; + // VLOG(3) << "Run WinogradConv Int8"; } else { impl_ = new GemmLikeConv; // VLOG(3) << "Run GemmLikeConvInt8"; diff --git a/lite/kernels/arm/conv_winograd.cc b/lite/kernels/arm/conv_winograd.cc index c6e06a243cc1d1f1c8dc35338d8183352c4f679a..f61c6109cdfd57b30c2b57390d21dec7c3bb3aa2 100644 --- a/lite/kernels/arm/conv_winograd.cc +++ b/lite/kernels/arm/conv_winograd.cc @@ -358,6 +358,9 @@ void WinogradConv::Run() { param, &ctx); } +#ifdef LITE_WITH_PROFILE + kernel_func_name_ = "conv_compute_2x2_3x3_int8"; +#endif } template class WinogradConv; template class WinogradConv; diff --git a/lite/kernels/arm/conv_winograd.h b/lite/kernels/arm/conv_winograd.h index 69835a74b40b4f08d78cb11f3b9415eef7bc89d6..b93a719f7dbb13aa9888ea943fa81b6ea2b38c00 100644 --- a/lite/kernels/arm/conv_winograd.h +++ b/lite/kernels/arm/conv_winograd.h @@ -61,6 +61,13 @@ class WinogradConv virtual void PrepareForRun(); virtual void ReInitWhenNeeded(); virtual void Run(); +#ifdef LITE_WITH_PROFILE + virtual void SetProfileRuntimeKernelInfo( + paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + } + std::string kernel_func_name_{"NotImplForConvWino"}; +#endif protected: using param_t = operators::ConvParam;