未验证 提交 3b4eee0f 编写于 作者: H HappyAngel 提交者: GitHub

[arm] add 3x3s1 Winograd int8 (#3767)

* fix: winograd support unsame pad
test=develop

* feat: add winograd int8 kernel
test=develop

* fix: style fix
test=develo

* fix winograd_int8 ut sgement default. test=develop

* close basic_test, test=develop
Co-authored-by: NMyPandaShaoxiang <txg4794@163.com>
上级 8e400754
......@@ -83,6 +83,7 @@ if (NOT HAS_ARM_MATH_LIB_DIR)
conv5x5s2_depthwise_int8.cc
conv5x5s2_depthwise_fp32.cc
conv3x3_winograd_fp32_c4.cc
conv3x3_winograd_int8.cc
conv_winograd_3x3.cc
conv_impl.cc
softmax.cc
......
......@@ -1245,7 +1245,7 @@ void weight_trans_c4_8x8(
for (int i = 0; i < ch_out * ch_in * 64; ++i) {
int new_c = i % 64;
int new_oc = i / ch_in / 64 / 4;
int new_ic = i / 64 % (ch_in * 4) % ch_in;
int new_ic = i / 64 % ch_in;
int new_inner = i / ch_in / 64 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
......@@ -1302,7 +1302,7 @@ void weight_trans_c4_4x4(
for (int i = 0; i < ch_out * ch_in * 16; ++i) {
int new_c = i % 16;
int new_oc = i / ch_in / 16 / 4;
int new_ic = i / 16 % (ch_in * 4) % ch_in;
int new_ic = i / 16 % ch_in;
int new_inner = i / ch_in / 16 % 4;
int dest_ind =
new_c * c_stride + new_oc * ic_pad * 4 + new_ic * 4 + new_inner;
......
此差异已折叠。
......@@ -3762,6 +3762,7 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
int w_stride = we - ws;
int valid_w = (we > width ? width : we) - ws;
int cnt = valid_w / 4;
int remain = valid_w & 3;
float32x4_t w_scale0 = vld1q_f32(scale);
float32x4_t w_scale1 = vld1q_f32(scale + 4);
......@@ -3818,10 +3819,10 @@ inline void write_int32_nchwc8_to_nchw(const int* din,
flag_act,
alpha);
}
if (we > width) {
if (remain > 0) {
int offset = 32 * cnt;
din_hei_ptr = ptr_din + offset;
for (int j = ws + cnt * 4; j < width; ++j) {
for (int j = 0; j < remain; ++j) {
if (flag_bias) {
*(doutc0_ptr++) = cvt_kernel<Dtype>(
din_hei_ptr[0], scale[0], bias[0], flag_act, alpha[0]);
......
......@@ -359,6 +359,35 @@ void conv_compute_2x2_3x3_small(const float* input,
const float* bias,
const operators::ConvParam& param,
ARMContext* ctx);
void input_trans_c8_4x4_int8(const int8_t* src,
int src_stride,
int src_h_stride,
int16_t* dest,
int dest_stride,
int dest_h_stride);
void output_trans_c8_post_2x4_int8(const int32_t* src,
int src_stride,
int src_h_stride,
int32_t* dest,
int dest_stride,
int dest_h_stride);
void weight_trans_c8_4x4_int8(
int16_t* dest, const int8_t* src, int ic, int oc, void* workspace);
template <typename Dtype>
void conv_compute_2x2_3x3_int8(const int8_t* input,
Dtype* output,
int num,
int chout,
int hout,
int wout,
int chin,
int hin,
int win,
const int16_t* weight,
const float* bias,
const float* scale,
const operators::ConvParam& param,
ARMContext* ctx);
template <typename Dtype>
void im2col(const Dtype* data_im,
......
......@@ -54,6 +54,13 @@ void sgemm_prepack_c4_small(int M,
const float* B,
float* C,
ARMContext* ctx);
void sgemm_prepack_c8_int16_small(int M,
int N,
int K,
const int16_t* A_packed,
const int16_t* B,
int32_t* C,
ARMContext* ctx);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -73,7 +73,6 @@ void ConvCompute<PRECISION(kFloat), PRECISION(kFloat)>::PrepareForRun() {
// VLOG(3) << "invoking dw conv";
} else if (param.groups == 1 && kw == 3 && stride == 1 && ks_equal &&
no_dilation) {
// TODO(MyPandaShaoxiang): winograd conv support any pad
impl_ = new WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>;
// VLOG(3) << "invoking winograd conv";
} else if (param.groups == 1 && kw == 3 && stride == 2 &&
......@@ -122,9 +121,9 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kFloat)>::PrepareForRun() {
no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kFloat)>;
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
// VLOG(3) << "Run DirectConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kFloat)>;
......@@ -169,9 +168,9 @@ void ConvCompute<PRECISION(kInt8), PRECISION(kInt8)>::PrepareForRun() {
no_dilation && flag_dw) {
impl_ = new DepthwiseConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DepthwiseConv Int8";
} else if (param.groups == 1 && kw == 3 && (sw == 1 || sw == 2) &&
ic * oc < 4 * hin * win && kps_equal && no_dilation) {
impl_ = new DirectConv<PRECISION(kInt8), PRECISION(kInt8)>;
} else if (param.groups == 1 && kw == 3 && sw == 1 && no_dilation &&
pads_equal) {
impl_ = new WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
// VLOG(3) << "Run DirectConv Int8";
} else {
impl_ = new GemmLikeConv<PRECISION(kInt8), PRECISION(kInt8)>;
......
......@@ -13,7 +13,6 @@
// limitations under the License.
#include "lite/kernels/arm/conv_winograd.h"
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/backends/arm/math/packed_sgemm.h"
......@@ -183,6 +182,186 @@ void WinogradConv<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
}
}
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::ReInitWhenNeeded() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
int threads = ctx.threads();
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
if (last_shape_ == x_dims) {
return;
}
last_shape_ = x_dims;
//! update workspace size
int ic = x_dims[1];
int ih = x_dims[2];
int iw = x_dims[3];
int oc = o_dims[1];
int oh = o_dims[2];
int ow = o_dims[3];
int tile_block = 8;
auto pad = *(param.paddings);
int pad_h0 = pad[0];
int pad_h1 = pad[1];
int pad_w0 = pad[2];
int pad_w1 = pad[3];
int oc_pad = (oc + 7) / 8 * 8;
int ic_pad = (ic + 7) / 8 * 8;
const int new_input_size =
ic_pad * (ih + pad_h0 + pad_h1) * (iw + pad_w0 + pad_w1) +
oc_pad * oh * ow * sizeof(int32_t);
int tmp_input_thread_size_byte =
tile_block * ic_pad * wino_iw * wino_iw * sizeof(int16_t);
int tmp_output_thread_size_byte =
tile_block * oc_pad * wino_iw * wino_iw * sizeof(int32_t);
const int temp_size =
(tmp_input_thread_size_byte + tmp_output_thread_size_byte +
wino_iw * wino_iw * (8 + 8 * sizeof(int32_t))) *
threads;
workspace_size_ = temp_size + new_input_size;
//! update trans weights impl
// choose_small_ = ow * oh / (tile_block * threads) < 36 ? true : false;
// we only support 2x2 now
choose_small_ = true;
float w_fact = 0.25;
if (choose_small_) {
wino_iw = 4;
if (last_function_ == 0) {
return;
}
last_function_ = 0;
} else {
wino_iw = 6;
if (last_function_ == 1) {
return;
}
last_function_ = 1;
}
/// update scale
for (auto& ws : w_scale_) {
ws *= w_fact;
}
weights_.Resize({1, 1, 1, wino_iw * wino_iw * oc_pad * ic_pad});
void* trans_tmp_ptr = malloc(sizeof(int16_t) * wino_iw * wino_iw * oc * ic);
auto weights_data_ = weights_.mutable_data<int16_t>();
if (!choose_small_) {
} else {
lite::arm::math::weight_trans_c8_4x4_int8(
weights_data_,
param.filter->template data<int8_t>(),
ic,
oc,
trans_tmp_ptr);
}
free(trans_tmp_ptr);
}
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::PrepareForRun() {
auto& param = this->Param<param_t>();
w_scale_ = param.weight_scale;
if (w_scale_.size() != 1 && w_scale_.size() != param.filter->dims()[0]) {
LOG(FATAL) << "weights scale size must equal to filter size";
return;
}
if (w_scale_.size() == 1) {
for (int i = 0; i < param.filter->dims()[0] - 1; ++i) {
w_scale_.push_back(w_scale_[0]);
}
}
float input_scale = param.input_scale;
for (auto& ws : w_scale_) {
ws *= input_scale;
}
if (param.bias) {
bias_.Resize(param.bias->dims());
auto ptr = bias_.mutable_data<float>();
auto ptr_in = param.bias->template data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] = ptr_in[i];
}
}
if (OutType == PRECISION(kInt8)) {
float output_scale = param.output_scale;
for (auto& ws : w_scale_) {
ws /= output_scale;
}
if (param.bias) {
auto ptr = bias_.mutable_data<float>();
for (int i = 0; i < bias_.numel(); ++i) {
ptr[i] /= output_scale;
}
}
}
ReInitWhenNeeded();
}
template <PrecisionType OutType>
void WinogradConv<PRECISION(kInt8), OutType>::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
ctx.ExtendWorkspace(workspace_size_);
const auto* i_data = param.x->template data<int8_t>();
const auto* w_data = weights_.data<int16_t>();
const auto* b_data = param.bias ? bias_.data<float>() : nullptr;
// const float* i_data;
auto x_dims = param.x->dims();
auto w_dims = param.filter->dims();
auto o_dims = param.output->dims();
int iw = x_dims[3]; // nchw
int ih = x_dims[2];
int ic = x_dims[1];
int bs = x_dims[0];
int oh = o_dims[2];
int ow = o_dims[3];
int oc = o_dims[1];
// now always choose small
if (OutType == PRECISION(kInt8)) {
auto* o_data = param.output->template mutable_data<int8_t>();
lite::arm::math::conv_compute_2x2_3x3_int8<int8_t>(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
w_scale_.data(),
param,
&ctx);
} else {
auto* o_data = param.output->template mutable_data<float>();
lite::arm::math::conv_compute_2x2_3x3_int8<float>(i_data,
o_data,
bs,
oc,
oh,
ow,
ic,
ih,
iw,
w_data,
b_data,
w_scale_.data(),
param,
&ctx);
}
}
template class WinogradConv<PRECISION(kInt8), PRECISION(kInt8)>;
template class WinogradConv<PRECISION(kInt8), PRECISION(kFloat)>;
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -16,11 +16,11 @@
#include <cmath>
#include <string>
#include <vector>
#include "lite/backends/arm/math/conv_impl.h"
#include "lite/core/context.h"
#include "lite/core/kernel.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace kernels {
......@@ -52,7 +52,27 @@ class WinogradConv : public KernelLite<TARGET(kARM), Ptype> {
bool choose_small_{false};
int wino_iw{8};
};
template <PrecisionType OutType>
class WinogradConv<PRECISION(kInt8), OutType>
: public KernelLite<TARGET(kARM), PRECISION(kInt8)> {
public:
WinogradConv() = default;
~WinogradConv() {}
virtual void PrepareForRun();
virtual void ReInitWhenNeeded();
virtual void Run();
protected:
using param_t = operators::ConvParam;
Tensor weights_;
Tensor bias_;
DDim last_shape_;
int workspace_size_{0};
int last_function_{-1};
bool choose_small_{true};
int wino_iw{4};
std::vector<float> w_scale_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -34,7 +34,7 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_bool(basic_test, true, "do all tests");
DEFINE_bool(basic_test, false, "do all tests");
DEFINE_bool(check_result, true, "check the result");
DEFINE_int32(batch, 1, "batch size");
......@@ -614,6 +614,9 @@ TEST(TestConv3x3s1Int8, test_conv_3x3s1) {
dims.push_back(DDim({batch, cin, h, h}));
}
}
if (cin == 1 && cout == 1) {
continue;
}
test_conv_int8(dims,
weights_dim,
1,
......
......@@ -179,6 +179,141 @@ bool test_sgemm_c4(
#endif
return true;
}
bool test_sgemm_c8(
int m, int n, int k, bool has_bias, bool has_relu, int cls, int ths) {
int m_round = (m + 7) / 8 * 8;
int k_round = (k + 7) / 8 * 8;
int size_a = m * k;
int size_b = n * k;
int size_a_c4 = m_round * k_round;
int size_b_c8 = k_round * n;
Tensor ta;
Tensor tb;
Tensor ta_c4;
Tensor tb_c8;
Tensor tc;
Tensor tc_basic;
Tensor tc_backup;
Tensor tbias;
ta.Resize({size_a});
tb.Resize({size_b});
ta_c4.Resize({size_a_c4});
tb_c8.Resize({size_b_c8});
tc.Resize({m_round * n});
tc_basic.Resize({m_round * n});
tbias.Resize({m});
ta.set_precision(PRECISION(kInt16));
tb.set_precision(PRECISION(kInt16));
ta_c4.set_precision(PRECISION(kInt16));
tb_c8.set_precision(PRECISION(kInt16));
tc.set_precision(PRECISION(kInt32));
tc_basic.set_precision(PRECISION(kInt32));
tbias.set_precision(PRECISION(kInt32));
fill_tensor_rand(ta);
fill_tensor_rand(tb);
fill_tensor_rand(tbias);
fill_tensor_rand(tc);
auto da = ta.mutable_data<int16_t>();
auto db = tb.mutable_data<int16_t>();
auto da_c4 = ta_c4.mutable_data<int16_t>();
auto db_c8 = tb_c8.mutable_data<int16_t>();
auto dc_basic = tc_basic.mutable_data<int32_t>();
auto dbias = tbias.mutable_data<int32_t>();
// trans A, B to c4
basic_trans_mat_to_c8(da, da_c4, k, m, k, true);
basic_trans_mat_to_c8(db, db_c8, n, k, n, false);
LOG(INFO) << "sgemm_c8 M: " << m << ", N: " << n << ", K: " << k
<< ", relu: " << (has_relu ? "true" : "false")
<< ", bias: " << (has_bias ? "true" : "false");
if (FLAGS_check_result) {
basic_gemm_c8(false,
false,
m,
n,
k,
1,
da,
k,
db,
n,
0,
dc_basic,
n,
dbias,
false,
false);
}
Timer t0;
LOG(INFO) << "basic test end";
#ifdef LITE_WITH_ARM
//! compute
double ops = 2.0 * m_round * n * k_round;
std::unique_ptr<paddle::lite::KernelContext> ctx1(
new paddle::lite::KernelContext);
auto& ctx = ctx1->As<paddle::lite::ARMContext>();
ctx.SetRunMode(static_cast<paddle::lite_api::PowerMode>(cls), ths);
auto dc = tc.mutable_data<int32_t>();
for (int j = 0; j < FLAGS_warmup; ++j) {
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
}
LOG(INFO) << "basic test end";
for (int i = 0; i < FLAGS_repeats; ++i) {
t0.Start();
paddle::lite::arm::math::sgemm_prepack_c8_int16_small(
m, n, k, da_c4, db_c8, dc, &ctx);
t0.Stop();
}
LOG(INFO) << "basic test end";
LOG(INFO) << "M: " << m << ", N: " << n << ", K: " << k
<< ", power_mode: " << cls << ", threads: " << ths
<< ", GOPS: " << ops * 1e-9f
<< " GOPS, avg time: " << t0.LapTimes().Avg()
<< " ms, min time: " << t0.LapTimes().Min()
<< " ms, mean GOPs: " << ops * 1e-6f / t0.LapTimes().Avg()
<< " GOPs, max GOPs: " << ops * 1e-6f / t0.LapTimes().Min()
<< " GOPs";
if (FLAGS_check_result) {
double max_ratio = 0;
double max_diff = 0;
tensor_cmp_host(tc_basic, tc, max_ratio, max_diff);
LOG(INFO) << "compare result, max diff: " << max_diff
<< ", max ratio: " << max_ratio;
if (std::abs(max_ratio) > 1e-4f && std::abs(max_diff) > 5e-5f) {
Tensor tdiff;
tdiff.set_precision(PRECISION(kInt32));
tdiff.Resize(tc.dims());
tensor_diff(tc_basic, tc, tdiff);
LOG(INFO) << "a: ";
print_tensor(ta);
LOG(INFO) << "a_c8: ";
print_tensor(ta_c4);
LOG(INFO) << "b: ";
print_tensor(tb);
LOG(INFO) << "b_c8: ";
print_tensor(tb_c8);
LOG(INFO) << "basic result: ";
print_tensor(tc_basic);
LOG(INFO) << "lite result: ";
print_tensor(tc);
LOG(INFO) << "diff result: ";
print_tensor(tdiff);
return false;
}
}
#endif
return true;
}
TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
if (FLAGS_basic_test) {
......@@ -186,11 +321,11 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789}) {
for (auto& k : {1, 3, 8, 59, 234}) {
for (auto& has_bias : {false, true}) {
for (auto& has_relu : {false, true}) {
for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false}) {
for (auto& has_relu : {false}) {
for (auto& th : {1, 2, 4}) {
auto flag = test_sgemm_c4(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
......@@ -213,8 +348,41 @@ TEST(TestSgemmC4, test_func_sgemm_c4_prepacked) {
}
}
}
TEST(TestSgemmC8, test_func_sgemm_c8_prepacked) {
if (FLAGS_basic_test) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
LOG(INFO) << "run basic sgemm_c4 test";
for (auto& m : {1, 3, 8, 32, 397, 32, 64, 77}) {
for (auto& n : {1, 2, 3, 4, 13, 141, 789, 1}) {
for (auto& k : {1, 3, 8, 59, 234, 19}) {
for (auto& has_bias : {false}) {
for (auto& has_relu : {false}) {
for (auto& th : {1}) {
auto flag = test_sgemm_c8(
m, n, k, has_bias, has_relu, FLAGS_power_mode, th);
if (flag) {
LOG(INFO) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " passed\n";
} else {
LOG(FATAL) << "test m = " << m << ", n=" << n << ", k=" << k
<< ", bias: " << (has_bias ? "true" : "false")
<< ", relu: " << (has_relu ? "true" : "false")
<< " failed\n";
}
}
}
}
}
}
}
}
}
TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
TEST(TestSgemmCnCustom, test_func_sgemm_cn_prepacked_custom) {
#ifdef LITE_WITH_ARM
paddle::lite::DeviceInfo::Init();
#endif
......@@ -230,6 +398,18 @@ TEST(TestSgemmC4Custom, test_func_sgemm_c4_prepacked_custom) {
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
flag = test_sgemm_c8(FLAGS_M,
FLAGS_N,
FLAGS_K,
FLAGS_flag_bias,
FLAGS_flag_relu,
FLAGS_power_mode,
FLAGS_threads);
if (!flag) {
LOG(FATAL) << "test m = " << FLAGS_M << ", n=" << FLAGS_N
<< ", k=" << FLAGS_K << ", bias: " << FLAGS_flag_bias
<< ", relu: " << FLAGS_flag_relu << " failed!!";
}
LOG(INFO) << "test m = " << FLAGS_M << ", n=" << FLAGS_N << ", k=" << FLAGS_K
<< ", bias: " << FLAGS_flag_bias << ", relu: " << FLAGS_flag_relu
<< " passed!!";
......
......@@ -62,6 +62,72 @@ static void basic_trans_mat_to_c4(const type* input,
}
delete[] zero_buf;
}
template <typename type>
static void basic_trans_mat_to_c8(const type* input,
type* output,
const int ldin,
const int M,
const int K,
bool pack_k) {
const int m_round = (M + 7) / 8 * 8;
int k_round = (K + 7) / 8 * 8;
if (!pack_k) {
k_round = K;
}
const int m_loop = m_round / 8;
type zero_buf[K];
memset(zero_buf, 0, K * sizeof(type));
for (int i = 0; i < m_loop; ++i) {
const type* in0 = input + i * 8 * ldin;
const type* in1 = in0 + ldin;
const type* in2 = in1 + ldin;
const type* in3 = in2 + ldin;
const type* in4 = in3 + ldin;
const type* in5 = in4 + ldin;
const type* in6 = in5 + ldin;
const type* in7 = in6 + ldin;
if (8 * (i + 1) - M > 0) {
switch (8 * (i + 1) - M) {
case 7:
in1 = zero_buf;
case 6:
in2 = zero_buf;
case 5:
in3 = zero_buf;
case 4:
in4 = zero_buf;
case 3:
in5 = zero_buf;
case 2:
in6 = zero_buf;
case 1:
in7 = zero_buf;
default:
break;
}
}
for (int j = 0; j < K; ++j) {
*output++ = *in0++;
*output++ = *in1++;
*output++ = *in2++;
*output++ = *in3++;
*output++ = *in4++;
*output++ = *in5++;
*output++ = *in6++;
*output++ = *in7++;
}
for (int j = K; j < k_round; ++j) {
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
*output++ = static_cast<type>(0);
}
}
}
template <typename type, typename type2>
static void basic_gemm_c4(bool trans_a,
......@@ -118,6 +184,60 @@ static void basic_gemm_c4(bool trans_a,
free(tmp_c);
}
template <typename type, typename type2>
static void basic_gemm_c8(bool trans_a,
bool trans_b,
int m,
int n,
int k,
type2 alpha,
const type* a,
int lda,
const type* b,
int ldb,
type2 beta,
type2* c,
int ldc,
const type2* bias,
bool flag_bias = false,
bool flag_relu = false) {
type2* tmp_c = reinterpret_cast<type2*>(malloc(m * ldc * sizeof(type2)));
memset(tmp_c, 0, m * ldc * sizeof(type2));
#pragma omp parallel for
for (int i = 0; i < m; ++i) {
auto bias_data = static_cast<type2>(0);
if (flag_bias) {
bias_data = bias[i];
}
for (int j = 0; j < n; ++j) {
auto sum = static_cast<type2>(0);
for (int l = 0; l < k; ++l) {
type av;
type bv;
if (trans_a) {
av = a[l * lda + i];
} else {
av = a[i * lda + l];
}
if (trans_b) {
bv = b[j * ldb + l];
} else {
bv = b[l * ldb + j];
}
sum += av * bv;
}
type2 tmp = alpha * sum + beta * tmp_c[i * ldc + j] + bias_data;
if (flag_relu) {
tmp_c[i * ldc + j] = tmp > (type2)0 ? tmp : (type2)0;
} else {
tmp_c[i * ldc + j] = tmp;
}
}
}
//! trans c to c4
basic_trans_mat_to_c8(tmp_c, c, ldc, m, n, false);
free(tmp_c);
}
template <typename type, typename type2>
static void basic_gemm(bool trans_a,
bool trans_b,
......
......@@ -50,6 +50,10 @@ void fill_tensor_const(Tensor& tensor, float value) { // NOLINT
fill_tensor_host_const_impl(
tensor.mutable_data<int8_t>(), static_cast<signed char>(value), size);
break;
case PRECISION(kInt16):
fill_tensor_host_const_impl(
tensor.mutable_data<int16_t>(), static_cast<int16_t>(value), size);
break;
case PRECISION(kInt32):
fill_tensor_host_const_impl(
tensor.mutable_data<int>(), static_cast<int>(value), size);
......@@ -78,6 +82,12 @@ void fill_tensor_host_rand_impl<signed char>(signed char* dio, int64_t size) {
}
}
template <>
void fill_tensor_host_rand_impl<int16_t>(int16_t* dio, int64_t size) {
for (int64_t i = 0; i < size; ++i) {
dio[i] = (rand() % 256 - 128) * 2; // NOLINT
}
}
template <>
void fill_tensor_host_rand_impl<unsigned char>(unsigned char* dio,
int64_t size) {
for (int64_t i = 0; i < size; ++i) {
......@@ -95,6 +105,9 @@ void fill_tensor_rand(Tensor& tensor) { // NOLINT
case PRECISION(kInt8):
fill_tensor_host_rand_impl(tensor.mutable_data<int8_t>(), size);
break;
case PRECISION(kInt16):
fill_tensor_host_rand_impl(tensor.mutable_data<int16_t>(), size);
break;
case PRECISION(kInt32):
fill_tensor_host_rand_impl(tensor.mutable_data<int>(), size);
break;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册