未验证 提交 53570d38 编写于 作者: H HappyAngel 提交者: GitHub

[cherry-pick] fix xiaodu crash and profiler (#3906)

* [arm]add 2x2s2p1 pooling  (#3705)

* fix pooling bug and speed

* add 2x2s2p1 pooling. test=develop

* fix conflict, test=develop

* fix conflict in wino

* [arm] add 3x3s1 Winograd int8 (#3767)

* fix: winograd support unsame pad
test=develop

* feat: add winograd int8 kernel
test=develop

* fix: style fix
test=develop

* fix winograd_int8 ut segmentation fault. test=develop

* close basic_test, test=develop
Co-authored-by: MyPandaShaoxiang <txg4794@163.com>

* fix xiaodu crash in gemm prepacked

* on huwen phones, 3x3s2p0 avg pooling crashes randomly; other phones do not show this behavior

* [arm] update conv int8 kernel choose (#3834)

* fix conv int8 kernel choose and softmax compute bug

* change axis_size = 4 kernel choose, test=develop

* fix format. test=develop

* fix format.test=develop

* fix build test=develop

* fix build error test=develop

* fix wino_int8 compute error. test=develop

* Update the link to debug, test=develop, test=document_fix (#3870) (#3871)
Co-authored-by: MyPandaShaoxiang <txg4794@163.com>
Co-authored-by: cc <52520497+juncaipeng@users.noreply.github.com>
上级 1166948a
...@@ -49,4 +49,4 @@ $ ./opt \ ...@@ -49,4 +49,4 @@ $ ./opt \
## 五. 测试工具 ## 五. 测试工具
为了使您更好的了解并使用Lite框架,我们向有进一步使用需求的用户开放了 [Debug工具](debug#debug)[Profile工具](debug#profiler)。Lite Model Debug Tool可以用来查找Lite框架与PaddlePaddle框架在执行预测时模型中的对应变量值是否有差异,进一步快速定位问题Op,方便复现与排查问题。Profile Monitor Tool可以帮助您了解每个Op的执行时间消耗,其会自动统计Op执行的次数,最长、最短、平均执行时间等等信息,为性能调优做一个基础参考。您可以通过 [相关专题](debug) 了解更多内容。 为了使您更好的了解并使用Lite框架,我们向有进一步使用需求的用户开放了 [Debug工具](debug)[Profile工具](debug)。Lite Model Debug Tool可以用来查找Lite框架与PaddlePaddle框架在执行预测时模型中的对应变量值是否有差异,进一步快速定位问题Op,方便复现与排查问题。Profile Monitor Tool可以帮助您了解每个Op的执行时间消耗,其会自动统计Op执行的次数,最长、最短、平均执行时间等等信息,为性能调优做一个基础参考。您可以通过 [相关专题](debug) 了解更多内容。
...@@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR) ...@@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s2_depthwise_int8.cc conv5x5s2_depthwise_int8.cc
conv5x5s2_depthwise_fp32.cc conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc conv3x3_winograd_fp32_c4.cc
conv3x3_winograd_int8.cc
conv_winograd_3x3.cc conv_winograd_3x3.cc
conv_impl.cc conv_impl.cc
softmax.cc softmax.cc
......
...@@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8( ...@@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8(
for (int i = 0; i < ch_out * ch_in * 64; ++i) { for (int i = 0; i < ch_out * ch_in * 64; ++i) {
int new_c = i % 64; int new_c = i % 64;
int new_oc = i / ch_in / 64 / 4; int new_oc = i / ch_in / 64 / 4;
int new_ic = i / 64 % (ch_in * 4) % ch_in; int new_ic = i / 64 % ch_in;
int new_inner = i / ch_in / 64 % 4; int new_inner = i / ch_in / 64 % 4;
int dest_ind = int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
...@@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4( ...@@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4(
for (int i = 0; i < ch_out * ch_in * 16; ++i) { for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16; int new_c = i % 16;
int new_oc = i / ch_in / 16 / 4; int new_oc = i / ch_in / 16 / 4;
int new_ic = i / 16 % (ch_in * 4) % ch_in; int new_ic = i / 16 % ch_in;
int new_inner = i / ch_in / 16 % 4; int new_inner = i / ch_in / 16 % 4;
int dest_ind = int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner; new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
......
此差异已折叠。
...@@ -3878,6 +3878,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din, ...@@ -3878,6 +3878,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
int w_stride = we - ws; int w_stride = we - ws;
int valid_w = (we > width ? width : we) - ws; int valid_w = (we > width ? width : we) - ws;
int cnt = valid_w / 4; int cnt = valid_w / 4;
int remain = valid_w & 3;
float32x4_t w_scale0 = vld1q_f32(scale); float32x4_t w_scale0 = vld1q_f32(scale);
float32x4_t w_scale1 = vld1q_f32(scale + 4); float32x4_t w_scale1 = vld1q_f32(scale + 4);
...@@ -3933,10 +3934,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din, ...@@ -3933,10 +3934,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
w_bias1, w_bias1,
flag_relu); flag_relu);
} }
if (we > width) { if (remain > 0) {
int offset = 32 * cnt; int offset = 32 * cnt;
din_hei_ptr = ptr_din + offset; din_hei_ptr = ptr_din + offset;
for (int j = ws + cnt * 4; j < width; ++j) { for (int j = 0; j < remain; ++j) {
if (flag_bias) { if (flag_bias) {
*(doutc0_ptr++) = *(doutc0_ptr++) =
cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu); cvt_kernel<Dtype>(din_hei_ptr[0], scale[0], bias[0], flag_relu);
......
...@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input, ...@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input,
const float* bias, const float* bias,
const operators::ConvParam& param, const operators::ConvParam& param,
ARMContext* ctx); ARMContext* ctx);
// Declaration only; the definition is not visible from this header.
// Reads int8 data from `src` and writes int16 data to `dest`; each side is
// described by a stride and a per-row ("h") stride.
// NOTE(review): judging by the name this is presumably the Winograd input
// transform for a 4x4 tile with 8-channel packing -- confirm against the
// implementation before relying on it.
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride);
// Declaration only; the definition is not visible from this header.
// Both `src` and `dest` are int32 buffers, each described by a stride and a
// per-row ("h") stride.
// NOTE(review): judging by the name this is presumably the second half of
// the Winograd output transform for 8-channel-packed tiles -- confirm against
// the implementation.
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride);
// Transforms int8 convolution weights `src` into the int16 layout stored in
// `dest`, using `workspace` as caller-provided scratch memory.
// The int8 WinogradConv kernel calls it with `dest` sized for
// 4 * 4 * oc_pad * ic_pad int16 values (oc/ic rounded up to a multiple of 8)
// and a workspace of sizeof(int16_t) * 4 * 4 * oc * ic bytes.
// NOTE(review): the exact workspace requirement is defined by the
// implementation, which is not visible here -- confirm before shrinking it.
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
// Int8 3x3 Winograd convolution entry point; `Dtype` is the output element
// type (the int8 WinogradConv kernel instantiates it with int8_t and float).
// The kernel passes: batch size as `num`, output channels/height/width,
// input channels/height/width, the int16 weights produced by
// weight_trans_c8_4x4_int8, an optional float bias (nullptr when the
// convolution has no bias) and one float scale per output channel.
// Declaration only; the definition is not visible from this header.
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
Dtype* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
template <typename Dtype> template <typename Dtype>
void im2col(const Dtype* data_im, void im2col(const Dtype* data_im,
......
...@@ -1922,19 +1922,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed, ...@@ -1922,19 +1922,45 @@ void gemm_prepack_oth_int8(const int8_t* A_packed,
Dtype* tmp1 = nullptr; Dtype* tmp1 = nullptr;
Dtype* tmp2 = nullptr; Dtype* tmp2 = nullptr;
Dtype* tmp3 = nullptr; Dtype* tmp3 = nullptr;
float32_t scale_local[4]; float32_t scale_local[4] = {0, 0, 0, 0};
float32_t bias_local[4] = {0, 0, 0, 0}; float32_t bias_local[4] = {0, 0, 0, 0};
if (is_bias) { if (is_bias) {
bias_local[0] = bias[y]; if (y + 4 <= M) {
bias_local[1] = bias[y + 1]; bias_local[0] = bias[y];
bias_local[2] = bias[y + 2]; bias_local[1] = bias[y + 1];
bias_local[3] = bias[y + 3]; bias_local[2] = bias[y + 2];
bias_local[3] = bias[y + 3];
} else {
switch (M - y) {
case 3:
bias_local[2] = bias[y + 2];
case 2:
bias_local[1] = bias[y + 1];
case 1:
bias_local[0] = bias[y + 0];
default:
break;
}
}
} }
if (scale) { if (scale) {
scale_local[0] = scale[y]; if (y + 4 <= M) {
scale_local[1] = scale[y + 1]; scale_local[0] = scale[y];
scale_local[2] = scale[y + 2]; scale_local[1] = scale[y + 1];
scale_local[3] = scale[y + 3]; scale_local[2] = scale[y + 2];
scale_local[3] = scale[y + 3];
} else {
switch (M - y) {
case 3:
scale_local[2] = scale[y + 2];
case 2:
scale_local[1] = scale[y + 1];
case 1:
scale_local[0] = scale[y + 0];
default:
break;
}
}
} }
if (y + MBLOCK_INT8_OTH > M) { if (y + MBLOCK_INT8_OTH > M) {
switch (y + MBLOCK_INT8_OTH - M) { switch (y + MBLOCK_INT8_OTH - M) {
......
...@@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M, ...@@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M,
const float* B, const float* B,
float* C, float* C,
ARMContext* ctx); ARMContext* ctx);
// GEMM on int16 operands producing int32 results: C (M x N) from a
// pre-packed A and a matrix B with reduction length K.
// Declaration only; the definition is not visible from this header.
// NOTE(review): the unit test feeds both A and B after converting them to an
// 8-wide packed ("c8") layout and sizes C as round_up(M, 8) * N -- presumably
// the required layouts; confirm against the implementation.
void sgemm_prepack_c8_int16_small(int M,
int N,
int K,
const int16_t* A_packed,
const int16_t* B,
int32_t* C,
ARMContext* ctx);
} // namespace math } // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
......
此差异已折叠。
...@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din, ...@@ -76,30 +76,55 @@ void pooling1x1s2p0_max(const float* din,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_max(const float* din, void pooling2x2s2p0_max(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
void pooling2x2s2_avg(const float* din, void pooling2x2s2p0_avg(const float* din,
float* dout, float* dout,
int num, int num,
int chout, int chout,
int hout, int hout,
int wout, int wout,
int chin, int chin,
int hin, int hin,
int win, int win,
bool exclusive, bool exclusive,
int pad_bottom, int pad_bottom,
int pad_right); int pad_right);
// Max pooling specialised for kernel 2x2, stride 2, leading padding 1.
// PoolCompute::Run dispatches here when ksize[0] == 2, strides[0] == 2 and
// paddings[0] == 1, passing out_dims[0..3] as num/chout/hout/wout,
// in_dims[1..3] as chin/hin/win, and paddings[1] / paddings[3] as
// pad_bottom / pad_right.
// Declaration only; the definition is not visible from this header.
void pooling2x2s2p1_max(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
int pad_bottom,
int pad_right);
// Average pooling specialised for kernel 2x2, stride 2, leading padding 1.
// PoolCompute::Run dispatches here when ksize[0] == 2, strides[0] == 2 and
// paddings[0] == 1, passing out_dims[0..3] as num/chout/hout/wout,
// in_dims[1..3] as chin/hin/win, the operator's `exclusive` flag, and
// paddings[1] / paddings[3] as pad_bottom / pad_right.
// NOTE(review): `exclusive` presumably excludes padded elements from the
// averaging divisor -- confirm against the implementation.
// Declaration only; the definition is not visible from this header.
void pooling2x2s2p1_avg(const float* din,
float* dout,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
bool exclusive,
int pad_bottom,
int pad_right);
void pooling3x3s1p1_max(const float* din, void pooling3x3s1p1_max(const float* din,
float* dout, float* dout,
......
...@@ -531,7 +531,7 @@ void softmax_inner1_large_axis<float>(const float* din, ...@@ -531,7 +531,7 @@ void softmax_inner1_large_axis<float>(const float* din,
} }
float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax)); float32x2_t vhmax = vmax_f32(vget_high_f32(vmax), vget_low_f32(vmax));
float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1)); float max_data = std::max(vget_lane_f32(vhmax, 0), vget_lane_f32(vhmax, 1));
for (j = 4 * j; j < axis_size; ++j) { for (j = 4 * nn; j < axis_size; ++j) {
max_data = std::max(max_data, din_max_ptr[0]); max_data = std::max(max_data, din_max_ptr[0]);
din_max_ptr++; din_max_ptr++;
} }
...@@ -557,7 +557,7 @@ void softmax_inner1_large_axis<float>(const float* din, ...@@ -557,7 +557,7 @@ void softmax_inner1_large_axis<float>(const float* din,
float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum)); float32x2_t vhsum = vadd_f32(vget_high_f32(vsum), vget_low_f32(vsum));
float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1); float sum_data = vget_lane_f32(vhsum, 0) + vget_lane_f32(vhsum, 1);
for (j = 4 * j; j < axis_size; ++j) { for (j = 4 * nn; j < axis_size; ++j) {
dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data); dout_sum_ptr[0] = expf(din_sum_ptr[0] - max_data);
sum_data += dout_sum_ptr[0]; sum_data += dout_sum_ptr[0];
din_sum_ptr++; din_sum_ptr++;
......
...@@ -50,13 +50,14 @@ class PoolingPE : public PE { ...@@ -50,13 +50,14 @@ class PoolingPE : public PE {
PoolingArgs args = {0}; PoolingArgs args = {0};
args.mode = param_.type; args.mode = param_.type;
auto paddings = *param_.paddings;
args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height)); args.kernel_reciprocal = fp32_2_fp16(1.0f / (k_width * k_height));
args.image.address = input->data<float16>(); args.image.address = input->data<float16>();
args.image.channels = input->shape().channel(); args.image.channels = input->shape().channel();
args.image.height = input->shape().height(); args.image.height = input->shape().height();
args.image.width = input->shape().width(); args.image.width = input->shape().width();
args.image.pad_height = param_.paddings[0]; args.image.pad_height = paddings[0];
args.image.pad_width = param_.paddings[1]; args.image.pad_width = paddings[2];
args.image.scale_address = input->scale(); args.image.scale_address = input->scale();
args.output.address = output->mutableData<float16>(); args.output.address = output->mutableData<float16>();
args.output.scale_address = output->scale(); args.output.scale_address = output->scale();
...@@ -69,8 +70,7 @@ class PoolingPE : public PE { ...@@ -69,8 +70,7 @@ class PoolingPE : public PE {
param_.poolingArgs = args; param_.poolingArgs = args;
// use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 // use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1
// && // && (k_width > 7 || k_height > 7);
// (k_width > 7 || k_height > 7);
use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 && use_cpu_ = output->shape().width() == 1 && output->shape().height() == 1 &&
(k_width > 255 || k_height > 255); (k_width > 255 || k_height > 255);
// use_cpu_ = param_.type == AVERAGE; // use_cpu_ = param_.type == AVERAGE;
...@@ -86,12 +86,13 @@ class PoolingPE : public PE { ...@@ -86,12 +86,13 @@ class PoolingPE : public PE {
float* image_addr = float_input.mutableData<float>(FP32, input->shape()); float* image_addr = float_input.mutableData<float>(FP32, input->shape());
float_input.copyFrom(input); float_input.copyFrom(input);
float16* data_out = output->data<float16>(); float16* data_out = output->data<float16>();
auto paddings = *param_.paddings;
int image_height = input->shape().height(); int image_height = input->shape().height();
int image_width = input->shape().width(); int image_width = input->shape().width();
int image_channels = input->shape().channel(); int image_channels = input->shape().channel();
int image_pad_h = param_.paddings[0]; int image_pad_h = paddings[0];
int image_pad_w = param_.paddings[1]; int image_pad_w = paddings[2];
int kernel_height = param_.kernelSize[1]; int kernel_height = param_.kernelSize[1];
int kernel_width = param_.kernelSize[0]; int kernel_width = param_.kernelSize[0];
int kernel_step_h = param_.strides[0]; int kernel_step_h = param_.strides[0];
......
...@@ -71,6 +71,9 @@ void ConcatCompute::Run() { ...@@ -71,6 +71,9 @@ void ConcatCompute::Run() {
auto* axis_tensor_data = axis_tensor->data<int>(); auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0]; axis = axis_tensor_data[0];
} }
if (axis < 0) {
axis += inputs[0]->dims().size();
}
switch (inputs.front()->precision()) { switch (inputs.front()->precision()) {
case PRECISION(kFloat): case PRECISION(kFloat):
......
...@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() { ...@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// VLOG(3) << "invoking dw conv"; // VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal && } else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation) { no_dilation) {
// TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>; impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv"; // VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 && } else if (param.groups == 1 && kw == 3 && stride == 2 &&
...@@ -122,10 +121,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() { ...@@ -122,10 +121,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
no_dilation && flag_dw) { no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DepthwiseConv Int8"; // VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) { pads_equal) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DirectConv Int8"; // VLOG(3) << "Run DirectConv Int8";
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run WinogradConv Int8";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>; impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run GemmLikeConvInt8"; // VLOG(3) << "Run GemmLikeConvInt8";
...@@ -169,10 +172,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() { ...@@ -169,10 +172,14 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
no_dilation && flag_dw) { no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DepthwiseConv Int8"; // VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) && } else if (param.groups == 1 && kw == 3 && sw == 2 && no_dilation &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) { pads_equal) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DirectConv Int8"; // VLOG(3) << "Run DirectConv Int8";
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run WinogradConv Int8";
} else { } else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>; impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run GemmLikeConvInt8"; // VLOG(3) << "Run GemmLikeConvInt8";
......
...@@ -13,7 +13,6 @@ ...@@ -13,7 +13,6 @@
// limitations under the License. // limitations under the License.
#include "lite/kernels/arm/conv_winograd.h" #include "lite/kernels/arm/conv_winograd.h"
#include <vector>
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm.h" #include "lite/backends/arm/math/packed_sgemm.h"
...@@ -166,6 +165,189 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -166,6 +165,189 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
} }
} }
/// Refreshes the shape-dependent state of the int8 Winograd convolution.
///
/// Does nothing when the input shape equals the one seen on the previous
/// call. Otherwise it
///   1. selects the Winograd variant (only the 4x4-tile variant exists, so
///      `choose_small_` is always true),
///   2. recomputes `workspace_size_` (bytes requested from the context in
///      Run()), and
///   3. the first time a variant is selected, folds the transform factor into
///      `w_scale_` and converts the int8 filter into the int16 layout kept in
///      `weights_`.
/// Step 3 is guarded by `last_function_`, so the scales are rescaled only
/// once even if the input shape keeps changing.
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<ARMContext>();
  const int threads = ctx.threads();
  auto x_dims = param.x->dims();
  auto o_dims = param.output->dims();
  if (last_shape_ == x_dims) {
    return;
  }
  last_shape_ = x_dims;

  //! choose the variant first, so that every size below is derived from the
  //! tile width that is actually going to be used
  // choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
  // we only support 2x2 now
  choose_small_ = true;
  wino_iw = choose_small_ ? 4 : 6;
  const int function_id = choose_small_ ? 0 : 1;

  //! update workspace size
  const int ic = x_dims[1];
  const int ih = x_dims[2];
  const int iw = x_dims[3];
  const int oc = o_dims[1];
  const int oh = o_dims[2];
  const int ow = o_dims[3];
  const int tile_block = 8;
  const auto& pad = *(param.paddings);
  const int pad_h0 = pad[0];
  const int pad_h1 = pad[1];
  const int pad_w0 = pad[2];
  const int pad_w1 = pad[3];
  // channels are processed in groups of 8, so round both counts up
  const int oc_pad = (oc + 7) / 8 * 8;
  const int ic_pad = (ic + 7) / 8 * 8;
  // padded int8 input (1 byte per element) plus an int32 output buffer
  const int new_input_size =
      ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) +
      oc_pad * oh * ow * sizeof(int32_t);
  const int tmp_input_thread_size_byte =
      tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t);
  const int tmp_output_thread_size_byte =
      tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t);
  // every thread gets its own transform buffers
  const int temp_size =
      (tmp_input_thread_size_byte + tmp_output_thread_size_byte +
       wino_iw * wino_iw * (8 + 8 * sizeof(int32_t))) *
      threads;
  workspace_size_ = temp_size + new_input_size;

  //! update trans weights impl (only when the variant changed)
  if (last_function_ == function_id) {
    return;
  }
  last_function_ = function_id;

  /// update scale
  const float w_fact = 0.25f;
  for (auto& ws : w_scale_) {
    ws *= w_fact;
  }

  weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
  auto weights_data_ = weights_.mutable_data<int16_t>();
  if (choose_small_) {
    // Scratch buffer for the weight transform; owned by the vector so it is
    // released on every exit path and an allocation failure is reported
    // instead of handing a null pointer to the transform.
    std::vector<int16_t> trans_tmp(static_cast<size_t>(wino_iw) * wino_iw *
                                   oc * ic);
    lite::arm::math::weight_trans_c8_4x4_int8(
        weights_data_,
        param.filter->template data<int8_t>(),
        ic,
        oc,
        trans_tmp.data());
  }
  // No weight transform exists for the larger variant yet, so nothing is
  // written to weights_ in that case.
}
/// One-time setup for the int8 Winograd convolution.
///
/// Builds one scale per filter in `w_scale_` (weight scale * input scale,
/// additionally divided by the output scale when the output is int8), keeps a
/// float copy of the bias in `bias_` (also divided by the output scale for
/// int8 output), and finally runs ReInitWhenNeeded().
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::PrepareForRun() {
  auto& conv_param = this->Param<param_t>();
  const auto filter_count = conv_param.filter->dims()[0];

  w_scale_ = conv_param.weight_scale;
  const bool scale_count_ok =
      w_scale_.size() == 1 || w_scale_.size() == filter_count;
  if (!scale_count_ok) {
    LOG(FATAL) << "weights scale size must equal to filter size";
    return;
  }

  // A single scale is shared by all filters: replicate it so that there is
  // exactly one entry per filter.
  if (w_scale_.size() == 1) {
    for (auto missing = filter_count - 1; missing > 0; --missing) {
      w_scale_.push_back(w_scale_[0]);
    }
  }

  const float in_scale = conv_param.input_scale;
  for (size_t c = 0; c < w_scale_.size(); ++c) {
    w_scale_[c] *= in_scale;
  }

  const bool has_bias = conv_param.bias != nullptr;
  if (has_bias) {
    bias_.Resize(conv_param.bias->dims());
    auto bias_dst = bias_.mutable_data<float>();
    auto bias_src = conv_param.bias->template data<float>();
    const auto bias_len = bias_.numel();
    for (int b = 0; b < bias_len; ++b) {
      bias_dst[b] = bias_src[b];
    }
  }

  // Int8 output is requantized, so scales and bias are expressed in units of
  // the output scale.
  if (OutType == PRECISION(kInt8)) {
    const float out_scale = conv_param.output_scale;
    for (size_t c = 0; c < w_scale_.size(); ++c) {
      w_scale_[c] /= out_scale;
    }
    if (has_bias) {
      auto bias_dst = bias_.mutable_data<float>();
      const auto bias_len = bias_.numel();
      for (int b = 0; b < bias_len; ++b) {
        bias_dst[b] /= out_scale;
      }
    }
  }

  ReInitWhenNeeded();
}
/// Runs the int8 3x3 Winograd convolution on the current input.
///
/// Requests `workspace_size_` bytes of scratch memory from the ARM context,
/// then calls conv_compute_2x2_3x3_int8 with the transformed int16 weights,
/// the prepared bias (nullptr when the convolution has no bias) and the
/// per-filter scales. The output element type follows `OutType`: int8_t for
/// int8 output, float otherwise.
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<ARMContext>();
  ctx.ExtendWorkspace(workspace_size_);

  const auto* i_data = param.x->template data<int8_t>();
  const auto* w_data = weights_.data<int16_t>();
  const auto* b_data = param.bias ? bias_.data<float>() : nullptr;

  auto x_dims = param.x->dims();
  auto o_dims = param.output->dims();

  const int iw = x_dims[3];  // nchw
  const int ih = x_dims[2];
  const int ic = x_dims[1];
  const int bs = x_dims[0];
  const int oh = o_dims[2];
  const int ow = o_dims[3];
  const int oc = o_dims[1];

  // now always choose small
  if (OutType == PRECISION(kInt8)) {
    auto* o_data = param.output->template mutable_data<int8_t>();
    lite::arm::math::conv_compute_2x2_3x3_int8<int8_t>(i_data,
                                                       o_data,
                                                       bs,
                                                       oc,
                                                       oh,
                                                       ow,
                                                       ic,
                                                       ih,
                                                       iw,
                                                       w_data,
                                                       b_data,
                                                       w_scale_.data(),
                                                       param,
                                                       &ctx);
  } else {
    auto* o_data = param.output->template mutable_data<float>();
    lite::arm::math::conv_compute_2x2_3x3_int8<float>(i_data,
                                                      o_data,
                                                      bs,
                                                      oc,
                                                      oh,
                                                      ow,
                                                      ic,
                                                      ih,
                                                      iw,
                                                      w_data,
                                                      b_data,
                                                      w_scale_.data(),
                                                      param,
                                                      &ctx);
  }
#ifdef LITE_WITH_PROFILE
  // reported to the profiler through SetProfileRuntimeKernelInfo
  kernel_func_name_ = "conv_compute_2x2_3x3_int8";
#endif
}
template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -15,11 +15,12 @@ ...@@ -15,11 +15,12 @@
#pragma once #pragma once
#include <cmath> #include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h" #include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h" #include "lite/core/context.h"
#include "lite/core/kernel.h" #include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h" #include "lite/core/target_wrapper.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace kernels { namespace kernels {
...@@ -44,7 +45,34 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> { ...@@ -44,7 +45,34 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
bool choose_small_{false}; bool choose_small_{false};
int wino_iw{8}; int wino_iw{8};
}; };
// Partial specialisation of WinogradConv for int8 input; `OutType` selects
// the output precision (explicitly instantiated for kInt8 and kFloat in the
// implementation file).
template <PrecisionType OutType>
class WinogradConv<PRECISION(kInt8), OutType>
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
WinogradConv() = default;
~WinogradConv() {}
// Prepares the per-filter scales and the bias, then calls ReInitWhenNeeded().
virtual void PrepareForRun();
// Recomputes the workspace size when the input shape changed and transforms
// the weights the first time a Winograd variant is selected.
virtual void ReInitWhenNeeded();
// Executes the convolution through conv_compute_2x2_3x3_int8.
virtual void Run();
#ifdef LITE_WITH_PROFILE
// Copies the name of the compute function used by Run() into the profiler
// record.
virtual void SetProfileRuntimeKernelInfo(
paddle::lite::profile::OpCharacter* ch) {
ch->kernel_func_name = kernel_func_name_;
}
// Placeholder value until Run() stores the real function name.
std::string kernel_func_name_{"NotImplForConvWino"};
#endif
protected:
using param_t = operators::ConvParam;
// Filter after the Winograd weight transform, stored as int16.
Tensor weights_;
// Float copy of the bias (divided by the output scale for int8 output).
Tensor bias_;
// Input dims seen by the last ReInitWhenNeeded(); used to skip repeated work.
DDim last_shape_;
// Scratch size in bytes passed to ARMContext::ExtendWorkspace in Run().
int workspace_size_{0};
// Variant whose weights are currently in weights_: -1 none yet, 0 for the
// 4x4-tile variant, 1 for the larger one.
int last_function_{-1};
// Always true for now: only the 4x4-tile variant is implemented.
bool choose_small_{true};
// Edge length of the Winograd input tile (4 for the small variant).
int wino_iw{4};
// One requantization scale per filter, handed to the compute function.
std::vector<float> w_scale_;
};
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -58,6 +58,7 @@ void PoolCompute::Run() { ...@@ -58,6 +58,7 @@ void PoolCompute::Run() {
bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) && bool global_pooling = (paddings[0] == 0) && (ksize[0] == in_dims[2]) &&
(ksize[1] == in_dims[3]) && kps_equal && pads_equal; (ksize[1] == in_dims[3]) && kps_equal && pads_equal;
global_pooling = param.global_pooling || global_pooling; global_pooling = param.global_pooling || global_pooling;
if (global_pooling) { if (global_pooling) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
paddings[2 * i] = 0; paddings[2 * i] = 0;
...@@ -107,35 +108,65 @@ void PoolCompute::Run() { ...@@ -107,35 +108,65 @@ void PoolCompute::Run() {
} else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling2x2s2_max(din, lite::arm::math::pooling2x2s2p0_max(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
out_dims[2], out_dims[2],
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
paddings[1], paddings[1],
paddings[3]); paddings[3]);
return; return;
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2_avg(din, lite::arm::math::pooling2x2s2p0_avg(din,
dout, dout,
out_dims[0], out_dims[0],
out_dims[1], out_dims[1],
out_dims[2], out_dims[2],
out_dims[3], out_dims[3],
in_dims[1], in_dims[1],
in_dims[2], in_dims[2],
in_dims[3], in_dims[3],
exclusive, exclusive,
paddings[1], paddings[1],
paddings[3]); paddings[3]);
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 && } else if (ksize[0] == 2 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { kps_equal) {
if (pooling_type == "max") {
lite::arm::math::pooling2x2s2p1_max(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
paddings[1],
paddings[3]);
return;
} else if (pooling_type == "avg") {
lite::arm::math::pooling2x2s2p1_avg(din,
dout,
out_dims[0],
out_dims[1],
out_dims[2],
out_dims[3],
in_dims[1],
in_dims[2],
in_dims[3],
exclusive,
paddings[1],
paddings[3]);
return;
}
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 1 &&
pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p1_max(din, lite::arm::math::pooling3x3s1p1_max(din,
dout, dout,
...@@ -165,7 +196,7 @@ void PoolCompute::Run() { ...@@ -165,7 +196,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 1 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s1p0_max(din, lite::arm::math::pooling3x3s1p0_max(din,
dout, dout,
...@@ -195,7 +226,7 @@ void PoolCompute::Run() { ...@@ -195,7 +226,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 0 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p0_max(din, lite::arm::math::pooling3x3s2p0_max(din,
dout, dout,
...@@ -225,7 +256,7 @@ void PoolCompute::Run() { ...@@ -225,7 +256,7 @@ void PoolCompute::Run() {
return; return;
} }
} else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 && } else if (ksize[0] == 3 && strides[0] == 2 && paddings[0] == 1 &&
kps_equal) { pads_equal && kps_equal) {
if (pooling_type == "max") { if (pooling_type == "max") {
lite::arm::math::pooling3x3s2p1_max(din, lite::arm::math::pooling3x3s2p1_max(din,
dout, dout,
......
...@@ -34,7 +34,7 @@ void SoftmaxCompute::Run() { ...@@ -34,7 +34,7 @@ void SoftmaxCompute::Run() {
int inner_num = x_dims.Slice(axis + 1, x_rank).production(); int inner_num = x_dims.Slice(axis + 1, x_rank).production();
int axis_size = x_dims[axis]; int axis_size = x_dims[axis];
if (inner_num == 1) { if (inner_num == 1) {
if (axis_size >= 4) { if (axis_size > 4) {
lite::arm::math::softmax_inner1_large_axis( lite::arm::math::softmax_inner1_large_axis(
din, dout, outer_num, axis_size); din, dout, outer_num, axis_size);
} else { } else {
......
...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode, ...@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num"); DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times"); DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times"); DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests"); DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result"); DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size"); DEFINE_int32(batch, 1, "batch size");
...@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, true, "with bias"); ...@@ -59,6 +59,7 @@ DEFINE_bool(flag_bias, true, "with bias");
typedef paddle::lite::DDim DDim; typedef paddle::lite::DDim DDim;
typedef paddle::lite::Tensor Tensor; typedef paddle::lite::Tensor Tensor;
typedef paddle::lite::operators::ConvParam ConvParam; typedef paddle::lite::operators::ConvParam ConvParam;
typedef paddle::lite::operators::ActivationParam ActivationParam;
using paddle::lite::profile::Timer; using paddle::lite::profile::Timer;
DDim compute_out_dim(const DDim& dim_in, DDim compute_out_dim(const DDim& dim_in,
...@@ -165,7 +166,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims, ...@@ -165,7 +166,18 @@ void test_conv_int8(const std::vector<DDim>& input_dims,
param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias); param_fp32_out.bias->CopyDataFrom(*param_int8_out.bias);
bias_fp32.CopyDataFrom(*param_int8_out.bias); bias_fp32.CopyDataFrom(*param_int8_out.bias);
} }
if (flag_relu) {
ActivationParam act_param;
act_param.has_active = true;
act_param.active_type = (paddle::lite_api::ActivationType)
flag_relu; // 1-relu, 2-relu6, 4-leakyrelu
if (flag_relu) {
param_fp32_out.fuse_relu = true;
param_int8_out.fuse_relu = true;
}
param_fp32_out.activation_param = act_param;
param_int8_out.activation_param = act_param;
}
std::vector<float> scale_in{1.f / 127}; std::vector<float> scale_in{1.f / 127};
std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f}; std::vector<float> scale_out{weight_dim.count(1, 4) / 127.f};
std::vector<float> scale_w(weight_dim[0], 1.f / 127); std::vector<float> scale_w(weight_dim[0], 1.f / 127);
...@@ -580,6 +592,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) { ...@@ -580,6 +592,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
dims.push_back(DDim({batch, cin, h, h})); dims.push_back(DDim({batch, cin, h, h}));
} }
} }
if (cin == 1 && cout == 1) {
continue;
}
test_conv_int8(dims, test_conv_int8(dims,
weights_dim, weights_dim,
1, 1,
......
...@@ -179,6 +179,141 @@ bool test_sgemm_c4( ...@@ -179,6 +179,141 @@ bool test_sgemm_c4(
#endif #endif
return true; return true;
} }
// Exercises the int16 prepacked "c8" GEMM kernel
// (paddle::lite::arm::math::sgemm_prepack_c8_int16_small) on random inputs:
// packs A and B with basic_trans_mat_to_c8, optionally computes a reference
// result with basic_gemm_c8, times the kernel, and compares both results.
//
// m, n, k   : GEMM sizes (A is m x k, B is k x n).
// has_bias  : only logged; the reference call below is hard-wired to
//             flag_bias = false and the kernel call takes no bias argument.
// has_relu  : only logged; same remark as has_bias.
// cls       : power mode, cast to paddle::lite_api::PowerMode.
// ths       : thread count passed to ARMContext::SetRunMode.
//
// Returns false only when FLAGS_check_result is set, LITE_WITH_ARM is
// defined, and the comparison exceeds BOTH thresholds; returns true in every
// other case (including builds without LITE_WITH_ARM, where only the
// reference path runs).
bool test_sgemm_c8(
int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
// Round m and k up to the next multiple of 8 (the packing width).
int m_round = (m + 7) / 8 * 8;
int k_round = (k + 7) / 8 * 8;
int size_a = m * k;
int size_b = n * k;
// NOTE: the "_c4" suffix is a leftover name; these hold the c8-packed data.
int size_a_c4 = m_round * k_round;
int size_b_c8 = k_round * n;
Tensor ta;
Tensor tb;
Tensor ta_c4;
Tensor tb_c8;
Tensor tc;
Tensor tc_basic;
// NOTE(review): tc_backup is declared but never used in this function.
Tensor tc_backup;
Tensor tbias;
ta.Resize({size_a});
tb.Resize({size_b});
ta_c4.Resize({size_a_c4});
tb_c8.Resize({size_b_c8});
// Both outputs are sized for m_round rows because results are kept in the
// padded c8 layout.
tc.Resize({m_round * n});
tc_basic.Resize({m_round * n});
tbias.Resize({m});
// Inputs are int16, outputs and bias are int32.
ta.set_precision(PRECISION(kInt16));
tb.set_precision(PRECISION(kInt16));
ta_c4.set_precision(PRECISION(kInt16));
tb_c8.set_precision(PRECISION(kInt16));
tc.set_precision(PRECISION(kInt32));
tc_basic.set_precision(PRECISION(kInt32));
tbias.set_precision(PRECISION(kInt32));
fill_tensor_rand(ta);
fill_tensor_rand(tb);
fill_tensor_rand(tbias);
// The kernel output buffer also starts out with random content.
fill_tensor_rand(tc);
auto da = ta.mutable_data<int16_t>();
auto db = tb.mutable_data<int16_t>();
auto da_c4 = ta_c4.mutable_data<int16_t>();
auto db_c8 = tb_c8.mutable_data<int16_t>();
auto dc_basic = tc_basic.mutable_data<int32_t>();
auto dbias = tbias.mutable_data<int32_t>();
// Pack A and B into the c8 layout: A pads its k dimension to k_round
// (pack_k = true), B keeps its n columns unpadded (pack_k = false).
basic_trans_mat_to_c8(da, da_c4, k, m, k, true);
basic_trans_mat_to_c8(db, db_c8, n, k, n, false);
LOG(INFO) << "sgemm_c8 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
// Reference result on the unpacked inputs: alpha = 1, beta = 0, no bias and
// no relu regardless of has_bias / has_relu.
basic_gemm_c8(false,
false,
m,
n,
k,
1,
da,
k,
db,
n,
0,
dc_basic,
n,
dbias,
false,
false);
}
Timer t0;
// NOTE(review): the same "basic test end" message is logged three times in
// this function (here, after warmup, and after the timed runs).
LOG(INFO) << "basic test end";
#ifdef LITE_WITH_ARM
//! compute
// Operation count is based on the padded sizes.
double ops = 2.0 * m_round * n * k_round;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
auto dc = tc.mutable_data<int32_t>();
// Untimed warmup runs.
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
}
LOG(INFO) << "basic test end";
// Timed runs; each iteration is recorded as one lap of t0.
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
t0.Stop();
}
LOG(INFO) << "basic test end";
LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
<< ", power_mode: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.LapTimes().Avg()
<< " ms, min time: " << t0.LapTimes().Min()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
// A mismatch is reported only when BOTH the ratio and the absolute
// difference exceed their thresholds; all tensors are dumped to help debug.
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kInt32));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "a_c8: ";
print_tensor(ta_c4);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "b_c8: ";
print_tensor(tb_c8);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
if (FLAGS_basic_test) { if (FLAGS_basic_test) {
...@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { ...@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
paddle::lite::DeviceInfo::Init(); paddle::lite::DeviceInfo::Init();
#endif #endif
LOG(INFO) << "run basic sgemm_c4 test"; LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397}) { for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789}) { for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234}) { for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false, true}) { for (auto& has_bias : {false}) {
for (auto& has_relu : {false, true}) { for (auto& has_relu : {false}) {
for (auto& th : {1, 2, 4}) { for (auto& th : {1, 2, 4}) {
auto flag = test_sgemm_c4( auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th); m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
...@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) { ...@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
} }
} }
} }
// Sweeps test_sgemm_c8 over a grid of (m, n, k) sizes when FLAGS_basic_test
// is set. Bias and relu are fixed to false and a single thread is used.
// Any failing combination aborts the run through LOG(FATAL).
TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) {
  if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
    paddle::lite::DeviceInfo::Init();
#endif
    // This test drives the c8 kernel, so the banner names c8 (it previously
    // said "sgemm_c4", copied from the c4 test).
    LOG(INFO) << "run basic sgemm_c8 test";
    for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
      for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
        for (auto& k : {1, 3, 8, 59, 234, 19}) {
          for (auto& has_bias : {false}) {
            for (auto& has_relu : {false}) {
              for (auto& th : {1}) {
                auto flag = test_sgemm_c8(
                    m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
                if (flag) {
                  LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
                            << ", bias: " << (has_bias ? "true" : "false")
                            << ", relu: " << (has_relu ? "true" : "false")
                            << " passed\n";
                } else {
                  LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
                             << ", bias: " << (has_bias ? "true" : "false")
                             << ", relu: " << (has_relu ? "true" : "false")
                             << " failed\n";
                }
              }
            }
          }
        }
      }
    }
  }
}
TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { TEST(TestSgemmCnCustom, test_func_sgemm_cn_prepacked_custom) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init(); paddle::lite::DeviceInfo::Init();
#endif #endif
...@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) { ...@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias << ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!"; << ", relu: " << FLAGS_flag_relu << " failed!!";
} }
flag = test_sgemm_c8(FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_power_mode,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu << ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!"; << " passed!!";
......
...@@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input, ...@@ -60,6 +60,72 @@ static void basic_trans_mat_to_c4(const type* input,
} }
} }
} }
// Repacks a row-major matrix into the 8-row interleaved ("c8") layout.
//
// The M rows are processed in groups of 8. For each group the output holds,
// column by column, the 8 values of that column (one per row of the group),
// i.e. out[group][col][row_in_group]. Rows past M are emitted as zeros so the
// row count is effectively padded to a multiple of 8.
//
// input  : source matrix, M rows of K valid elements, row stride ldin.
// output : destination, must hold round_up(M, 8) * k_round elements, where
//          k_round is round_up(K, 8) when pack_k is true and K otherwise.
// ldin   : distance in elements between consecutive input rows.
// M, K   : number of rows / columns of the source matrix.
// pack_k : when true, zero columns are appended so each group has a multiple
//          of 8 columns.
//
// Padding is produced by index checks rather than by reading from a zeroed
// scratch row, so no variable-length array (a non-standard C++ extension)
// and no pointer outside the input matrix is ever needed.
template <typename type>
static void basic_trans_mat_to_c8(const type* input,
                                  type* output,
                                  const int ldin,
                                  const int M,
                                  const int K,
                                  bool pack_k) {
  const int kPack = 8;
  const int m_round = (M + kPack - 1) / kPack * kPack;
  const int k_round = pack_k ? (K + kPack - 1) / kPack * kPack : K;
  const int m_loop = m_round / kPack;
  for (int i = 0; i < m_loop; ++i) {
    const int row_base = i * kPack;
    for (int j = 0; j < k_round; ++j) {
      // Columns >= K only exist when pack_k is set and are pure padding.
      const bool col_valid = j < K;
      for (int r = 0; r < kPack; ++r) {
        const int src_row = row_base + r;
        // Rows >= M belong to the zero padding of the last group.
        const bool in_range = col_valid && src_row < M;
        *output++ =
            in_range ? input[src_row * ldin + j] : static_cast<type>(0);
      }
    }
  }
}
template <typename type, typename type2> template <typename type, typename type2>
static void basic_gemm_c4(bool trans_a, static void basic_gemm_c4(bool trans_a,
...@@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a, ...@@ -116,6 +182,60 @@ static void basic_gemm_c4(bool trans_a,
free(tmp_c); free(tmp_c);
} }
// Reference GEMM whose result is written in the 8-row interleaved ("c8")
// layout.
//
// A plain row-major m x n result is first accumulated in a zero-initialised
// scratch buffer with row stride ldc, then repacked into `c` by
// basic_trans_mat_to_c8.
//
// trans_a / trans_b : read a / b transposed (element (i, l) of A is taken
//                     from a[l * lda + i] instead of a[i * lda + l], and
//                     likewise for b).
// m, n, k           : GEMM sizes.
// alpha, beta       : scaling factors; beta multiplies the scratch buffer's
//                     current value, which is zero because the buffer is
//                     cleared right after allocation.
// lda, ldb, ldc     : leading dimensions of a, b and the scratch result.
// c                 : destination in c8 layout.
// bias              : per-row bias, read only when flag_bias is true.
// flag_relu         : clamp every result at zero from below.
template <typename type, typename type2>
static void basic_gemm_c8(bool trans_a,
                          bool trans_b,
                          int m,
                          int n,
                          int k,
                          type2 alpha,
                          const type* a,
                          int lda,
                          const type* b,
                          int ldb,
                          type2 beta,
                          type2* c,
                          int ldc,
                          const type2* bias,
                          bool flag_bias = false,
                          bool flag_relu = false) {
  const auto scratch_bytes = m * ldc * sizeof(type2);
  type2* scratch = reinterpret_cast<type2*>(malloc(scratch_bytes));
  memset(scratch, 0, scratch_bytes);
#pragma omp parallel for
  for (int row = 0; row < m; ++row) {
    const type2 row_bias = flag_bias ? bias[row] : static_cast<type2>(0);
    for (int col = 0; col < n; ++col) {
      type2& cell = scratch[row * ldc + col];
      auto acc = static_cast<type2>(0);
      for (int depth = 0; depth < k; ++depth) {
        const type lhs = trans_a ? a[depth * lda + row] : a[row * lda + depth];
        const type rhs = trans_b ? b[col * ldb + depth] : b[depth * ldb + col];
        acc += lhs * rhs;
      }
      const type2 value = alpha * acc + beta * cell + row_bias;
      // With relu enabled, anything that is not strictly positive becomes 0.
      const bool clamp = flag_relu && !(value > static_cast<type2>(0));
      cell = clamp ? static_cast<type2>(0) : value;
    }
  }
  //! repack the row-major result into the c8 layout
  basic_trans_mat_to_c8(scratch, c, ldc, m, n, false);
  free(scratch);
}
template <typename type, typename type2> template <typename type, typename type2>
static void basic_gemm(bool trans_a, static void basic_gemm(bool trans_a,
bool trans_b, bool trans_b,
......
...@@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT ...@@ -41,6 +41,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
fill_tensor_host_const_impl( fill_tensor_host_const_impl(
tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size); tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
break; break;
case PRECISION(kInt16):
fill_tensor_host_const_impl(
tensor.mutable_data<int16_t>(), static_cast<int16_t>(value), size);
break;
case PRECISION(kInt32): case PRECISION(kInt32):
fill_tensor_host_const_impl( fill_tensor_host_const_impl(
tensor.mutable_data<int>(), static_cast<int>(value), size); tensor.mutable_data<int>(), static_cast<int>(value), size);
...@@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) { ...@@ -69,6 +73,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
} }
} }
template <> template <>
// int16_t specialisation: fills `dio[0 .. size)` with pseudo-random even
// values. Each element is (rand() % 256 - 128) doubled, so the results lie
// in [-256, 254].
void fill_tensor_host_rand_impl<int16_t>(int16_t* dio, int64_t size) {
  int64_t idx = 0;
  while (idx < size) {
    const int centered = rand() % 256 - 128;  // NOLINT
    dio[idx] = centered * 2;
    ++idx;
  }
}
template <>
void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio, void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
int64_t size) { int64_t size) {
for (int64_t i = 0; i < size; ++i) { for (int64_t i = 0; i < size; ++i) {
...@@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT ...@@ -86,6 +96,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT
case PRECISION(kInt8): case PRECISION(kInt8):
fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size); fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
break; break;
case PRECISION(kInt16):
fill_tensor_host_rand_impl(tensor.mutable_data<int16_t>(), size);
break;
case PRECISION(kInt32): case PRECISION(kInt32):
fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size); fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
break; break;
......
...@@ -678,15 +678,9 @@ void resize(const uint8_t* src, ...@@ -678,15 +678,9 @@ void resize(const uint8_t* src,
} else if (srcFormat == NV12 || srcFormat == NV21) { } else if (srcFormat == NV12 || srcFormat == NV21) {
nv21_resize(src, dst, srcw, srch, dstw, dsth); nv21_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
num = 1;
int hout = static_cast<int>(0.5 * dsth);
dsth += hout;
} else if (srcFormat == BGR || srcFormat == RGB) { } else if (srcFormat == BGR || srcFormat == RGB) {
bgr_resize(src, dst, srcw, srch, dstw, dsth); bgr_resize(src, dst, srcw, srch, dstw, dsth);
return; return;
w_in = srcw * 3;
w_out = dstw * 3;
num = 3;
} else if (srcFormat == BGRA || srcFormat == RGBA) { } else if (srcFormat == BGRA || srcFormat == RGBA) {
w_in = srcw * 4; w_in = srcw * 4;
w_out = dstw * 4; w_out = dstw * 4;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册