未验证 提交 9a2a720c 编写于 作者: H huzhiqiang 提交者: GitHub

[Arm] Update Fc arm kernel implementation to reduce memory usage (#3949)

上级 da130862
......@@ -88,7 +88,7 @@ void FcCompute<PRECISION(kFloat), PRECISION(kFloat)>::Run() {
auto i_data = param.input->data<float>();
auto o_data = param.output->mutable_data<float>();
- auto w_data = flag_gemm_ ? param.w->data<float>() : weights_.data<float>();
+ auto w_data = param.w->data<float>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -149,8 +149,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kFloat)>::Run() {
auto i_data = param.input->data<int8_t>();
auto o_data = param.output->mutable_data<float>();
- auto w_data =
- flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
+ auto w_data = param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......@@ -208,8 +207,7 @@ void FcCompute<PRECISION(kInt8), PRECISION(kInt8)>::Run() {
auto i_data = param.input->data<int8_t>();
auto o_data = param.output->mutable_data<int8_t>();
- auto w_data =
- flag_trans_weights_ ? weights_.data<int8_t>() : param.w->data<int8_t>();
+ auto w_data = param.w->data<int8_t>();
const float* b_data = param.bias ? param.bias->data<float>() : nullptr;
if (flag_trans_bias_) {
b_data = bias_.data<float>();
......
......@@ -104,9 +104,11 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
CHECK_EQ(k_, static_cast<int>(w_dims[0]));
flag_gemm_ = check_fc_use_gemm<PType, OutType>(
m_, param.weight_scale, param.bias != nullptr);
- if (!flag_trans_weights_ && !flag_gemm_) {
- flag_trans_weights_ = true;
- fc_trans_weights<PType>(*param.w, &weights_);
+ if (flag_trans_weights_ == flag_gemm_) {
+ flag_trans_weights_ = !flag_trans_weights_;
+ Tensor tmp_tensor;
+ fc_trans_weights<PType>(*param.w, &tmp_tensor);
+ param.w->CopyDataFrom(tmp_tensor);
}
}
......@@ -117,7 +119,6 @@ class FcCompute : public KernelLite<TARGET(kARM), PType> {
private:
DDim last_shape_;
- Tensor weights_;
Tensor bias_;
bool flag_trans_weights_{false};
bool flag_trans_bias_{false};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册