Commit ee1c6d91 authored by chenjiaoAngel

fix conv int8 kernel selection and softmax compute bug

Parent 77b3062a
@@ -531,7 +531,7 @@ void softmax_inner1_large_axis<float>(const float* din,
   }
   float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
   float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
-  for (j = 4 * j; j < axis_size; ++j) {
+  for (j = 4 * nn; j < axis_size; ++j) {
     max_data = std::max(max_data, din_max_ptr[0]);
     din_max_ptr++;
   }
@@ -557,7 +557,7 @@ void softmax_inner1_large_axis<float>(const float* din,
   float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
   float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
-  for (j = 4 * j; j < axis_size; ++j) {
+  for (j = 4 * nn; j < axis_size; ++j) {
     dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
     sum_data += dout_sum_ptr[0];
     din_sum_ptr++;
     dout_sum_ptr++;
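Both softmax hunks fix the same tail-loop bug. Assuming nn counts the full four-float groups (as the replacement 4 * nn suggests), the NEON body consumes 4 * nn elements, so the scalar remainder must resume at element 4 * nn; the old init expression 4 * j depended on whatever value j held after the vectorized body, which is what the commit message calls the softmax compute bug. Below is a plain C++ sketch of the corrected two-pass pattern (hypothetical function name softmax_axis; scalar loops stand in for the vmaxq_f32/vaddq_f32 intrinsics), not the actual Paddle-Lite implementation:

#include <algorithm>
#include <cfloat>
#include <cmath>

// Max pass + exp/sum pass over one softmax axis, blocked by 4 like the
// NEON code; both scalar tails start at 4 * nn (the fix in this commit).
void softmax_axis(const float* din, float* dout, int axis_size) {
  const int nn = axis_size >> 2;  // number of full 4-element groups

  // Pass 1: max reduction (the inner loop stands in for vmaxq_f32).
  float lane[4] = {-FLT_MAX, -FLT_MAX, -FLT_MAX, -FLT_MAX};
  for (int i = 0; i < nn; ++i) {
    for (int k = 0; k < 4; ++k) {
      lane[k] = std::max(lane[k], din[4 * i + k]);
    }
  }
  float max_data = std::max(std::max(lane[0], lane[1]),
                            std::max(lane[2], lane[3]));
  for (int j = 4 * nn; j < axis_size; ++j) {  // corrected tail start
    max_data = std::max(max_data, din[j]);
  }

  // Pass 2: exponentiate and accumulate the sum, same blocking.
  float sum_data = 0.f;
  for (int j = 0; j < 4 * nn; ++j) {  // vector body, flattened
    dout[j] = expf(din[j] - max_data);
    sum_data += dout[j];
  }
  for (int j = 4 * nn; j < axis_size; ++j) {  // corrected tail start
    dout[j] = expf(din[j] - max_data);
    sum_data += dout[j];
  }

  // Pass 3: normalize by the accumulated sum.
  const float inv = 1.f / sum_data;
  for (int j = 0; j < axis_size; ++j) dout[j] *= inv;
}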
@@ -121,10 +121,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
              no_dilation && flag_dw) {
     impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
     // VLOG(3) << "Run DepthwiseConv Int8";
+  } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
+             pads_equal) {
+    impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
+    // VLOG(3) << "Run DirectConv Int8";
   } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
              pads_equal) {
     impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
-    // VLOG(3) << "Run DirectConv Int8";
+    // VLOG(3) << "Run WinogradConv Int8";
   } else {
     impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
     // VLOG(3) << "Run GemmLikeConvInt8";
@@ -168,10 +172,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
              no_dilation && flag_dw) {
     impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
     // VLOG(3) << "Run DepthwiseConv Int8";
+  } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
+             pads_equal) {
+    impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
+    // VLOG(3) << "Run DirectConv Int8";
   } else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
              pads_equal) {
     impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
-    // VLOG(3) << "Run DirectConv Int8";
+    // VLOG(3) << "Run WinogradConv Int8";
   } else {
     impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
     // VLOG(3) << "Run GemmLikeConvInt8";
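Taken together, the two PrepareForRun() hunks make both int8 entry points (float output and int8 output) choose kernels in the same order: depthwise first, then 3x3 stride-2 to DirectConv (the branch this commit adds), 3x3 stride-1 to WinogradConv, and GemmLikeConv as the general fallback. A condensed sketch of that dispatch order follows (hypothetical helper returning a name; the real code instantiates kernel objects, and flag_dw bundles further depthwise preconditions):

#include <string>

// Condensed view of the int8 kernel choice above (hypothetical helper).
// Branch order matters: depthwise is tested before direct/winograd.
std::string pick_int8_kernel(int groups, int kw, int sw,
                             bool no_dilation, bool pads_equal,
                             bool flag_dw) {
  if (no_dilation && flag_dw) return "DepthwiseConv";
  if (groups == 1 && kw == 3 && sw == 2 && no_dilation && pads_equal)
    return "DirectConv";    // 3x3s2: branch added by this commit
  if (groups == 1 && kw == 3 && sw == 1 && no_dilation && pads_equal)
    return "WinogradConv";  // 3x3s1
  return "GemmLikeConv";    // general fallback
}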
@@ -358,6 +358,9 @@ void WinogradConv<PRECISION(kInt8), OutType>::Run() {
         param,
         &ctx);
   }
+#ifdef LITE_WITH_PROFILE
+  kernel_func_name_ = "conv_compute_2x2_3x3_int8";
+#endif
 }
 template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
 template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
@@ -61,6 +61,13 @@ class WinogradConv<PRECISION(kInt8), OutType>
   virtual void PrepareForRun();
   virtual void ReInitWhenNeeded();
   virtual void Run();
+#ifdef LITE_WITH_PROFILE
+  virtual void SetProfileRuntimeKernelInfo(
+      paddle::lite::profile::OpCharacter* ch) {
+    ch->kernel_func_name = kernel_func_name_;
+  }
+  std::string kernel_func_name_{"NotImplForConvWino"};
+#endif

  protected:
   using param_t = operators::ConvParam;
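The last two hunks wire the Winograd kernel into the LITE_WITH_PROFILE reporting path: Run() records which inner compute routine actually executed, and the profiler pulls that name out through SetProfileRuntimeKernelInfo(). A stripped-down, self-contained sketch of the pattern (hypothetical class names; the real OpCharacter carries more fields than shown):

#include <iostream>
#include <string>

// Minimal model of the LITE_WITH_PROFILE hook (hypothetical types).
struct OpCharacter {
  std::string kernel_func_name{"unknown"};
};

class KernelBase {
 public:
  virtual ~KernelBase() = default;
  virtual void Run() = 0;
  // The profiler calls this after Run() to learn which kernel ran.
  virtual void SetProfileRuntimeKernelInfo(OpCharacter* ch) = 0;
};

class WinogradLikeKernel : public KernelBase {
 public:
  void Run() override {
    // ... dispatch to the actual compute routine ...
    kernel_func_name_ = "conv_compute_2x2_3x3_int8";  // recorded per run
  }
  void SetProfileRuntimeKernelInfo(OpCharacter* ch) override {
    ch->kernel_func_name = kernel_func_name_;
  }

 private:
  std::string kernel_func_name_{"NotImplForConvWino"};  // default until Run()
};

int main() {
  WinogradLikeKernel k;
  OpCharacter ch;
  k.Run();
  k.SetProfileRuntimeKernelInfo(&ch);
  std::cout << ch.kernel_func_name << "\n";  // conv_compute_2x2_3x3_int8
}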