未验证 提交 9a2a720c 编写于 作者: H huzhiqiang 提交者: GitHub

[Arm] Update Fc arm kernel implementation to reduce memory usage (#3949)

上级 da130862
...@@ -88,7 +88,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() { ...@@ -88,7 +88,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto i_data = param.input->data<float>(); auto i_data = param.input->data<float>();
auto o_data = param.output->mutable_data<float>(); auto o_data = param.output->mutable_data<float>();
auto w_data = flag_gemm_ ? param.w->data<float>() : weights_.data<float>(); auto w_data = param.w->data<float>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr; const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
...@@ -149,8 +149,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() { ...@@ -149,8 +149,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto i_data = param.input->data<int8_t>(); auto i_data = param.input->data<int8_t>();
auto o_data = param.output->mutable_data<float>(); auto o_data = param.output->mutable_data<float>();
auto w_data = auto w_data = param.w->data<int8_t>();
flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr; const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
...@@ -208,8 +207,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() { ...@@ -208,8 +207,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto i_data = param.input->data<int8_t>(); auto i_data = param.input->data<int8_t>();
auto o_data = param.output->mutable_data<int8_t>(); auto o_data = param.output->mutable_data<int8_t>();
auto w_data = auto w_data = param.w->data<int8_t>();
flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr; const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) { if (flag_trans_bias_) {
b_data = bias_.data<float>(); b_data = bias_.data<float>();
......
...@@ -104,9 +104,11 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> { ...@@ -104,9 +104,11 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
CHECK_EQ(k_, static_cast<int>(w_dims[0])); CHECK_EQ(k_, static_cast<int>(w_dims[0]));
flag_gemm_ = check_fc_use_gemm<PType, OutType>( flag_gemm_ = check_fc_use_gemm<PType, OutType>(
m_, param.weight_scale, param.bias != nullptr); m_, param.weight_scale, param.bias != nullptr);
if (!flag_trans_weights_ && !flag_gemm_) { if (flag_trans_weights_ == flag_gemm_) {
flag_trans_weights_ = true; flag_trans_weights_ = !flag_trans_weights_;
fc_trans_weights<PType>(*param.w, &weights_); Tensor tmp_tensor;
fc_trans_weights<PType>(*param.w, &tmp_tensor);
param.w->CopyDataFrom(tmp_tensor);
} }
} }
...@@ -117,7 +119,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> { ...@@ -117,7 +119,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
private: private:
DDim last_shape_; DDim last_shape_;
Tensor weights_;
Tensor bias_; Tensor bias_;
bool flag_trans_weights_{false}; bool flag_trans_weights_{false};
bool flag_trans_bias_{false}; bool flag_trans_bias_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册