Commit 695ca28c authored by J juncaipeng, committed by GitHub

Support weight quantization (#2791)

* optimize quant_dequant_fuse_pass, test=develop

* update, test=develop

* update, test=develop

* fix bug for accessing the removed node, test=develop

* set the bias of int8 conv as float, test=develop

* support weight quantization, test=develop

* up, test=develop

* up, test=develop

* up, test=develop
Parent b7758556
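At its core, the DequantizeWeight step added below multiplies each stored integer weight by its per-output-channel scale before inference. A minimal standalone sketch of the idea, using a hypothetical function name and a flat row-major layout (not the Paddle Lite API):

// Minimal sketch: per-channel dequantization of an int8 conv weight
// flattened so that each row corresponds to one output channel.
#include <cstdint>
#include <vector>

std::vector<float> DequantizePerChannel(const std::vector<int8_t>& int_weight,
                                        const std::vector<float>& scales,
                                        int64_t rows, int64_t cols) {
  // Each output channel (row) i is restored as scales[i] * quantized value.
  std::vector<float> fp_weight(int_weight.size());
  for (int64_t i = 0; i < rows; ++i) {
    for (int64_t j = 0; j < cols; ++j) {
      fp_weight[i * cols + j] = scales[i] * int_weight[i * cols + j];
    }
  }
  return fp_weight;
}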
......@@ -41,6 +41,8 @@ void LightPredictor::Build(const std::string& model_dir,
default:
LOG(FATAL) << "Unknown model type";
}
DequantizeWeight();
BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch();
}
......@@ -144,5 +146,69 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
program_->set_exec_scope(program.exec_scope());
}
void LightPredictor::DequantizeWeight() {
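// PROCESS_CONV2D_DATA: per-channel dequantization; row i of the flattened
// weight (output channel i) uses scale_list[i].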
#define PROCESS_CONV2D_DATA() \
for (int64_t i = 0; i < h; ++i) { \
for (int64_t j = 0; j < w; ++j) { \
fp_data[i * w + j] = scale_list[i] * int_data[i * w + j]; \
} \
}
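// PROCESS_FC_DATA: per-layer dequantization; every element of the fc/mul
// weight shares the single scale_list[0].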
#define PROCESS_FC_DATA() \
for (int i = 0; i < input_tensor->numel(); i++) { \
*fp_data = scale_list[0] * (*int_data); \
++fp_data; \
++int_data; \
}
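// Keep a copy of the original integer weights, since input_tensor is
// overwritten in place with the dequantized float data.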
Tensor tmp_tensor;
CHECK(cpp_program_desc_.BlocksSize());
auto* main_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
for (size_t k = 0; k < main_block->OpsSize(); ++k) {
auto* op_desc = main_block->GetOp<cpp::OpDesc>(k);
if (op_desc->HasAttr("quantize_weight_bits")) { // weight quantized op
auto input_names = op_desc->input_vars();
for (auto& input_name : input_names) {
std::string input_scale_name = input_name + "_quant_scale";
if (op_desc->HasAttr(input_scale_name)) { // the input is quantized
auto input_tensor =
scope_->FindVar(input_name)->GetMutable<lite::Tensor>();
tmp_tensor.CopyDataFrom(*input_tensor);
auto scale_list =
op_desc->GetAttr<std::vector<float>>(input_scale_name);
int quantize_weight_bits =
op_desc->GetAttr<int>("quantize_weight_bits");
float* fp_data = input_tensor->mutable_data<float>();
std::string op_type = op_desc->Type();
if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
int64_t h = input_tensor->dims()[0];
int64_t w = input_tensor->numel() / h;
CHECK_EQ(scale_list.size(), h);
if (quantize_weight_bits == 8) {
const int8_t* int_data = tmp_tensor.data<int8_t>();
PROCESS_CONV2D_DATA()
} else {
const int16_t* int_data = tmp_tensor.data<int16_t>();
PROCESS_CONV2D_DATA()
}
} else if (op_type == "fc" || op_type == "mul") {
if (quantize_weight_bits == 8) {
const int8_t* int_data = tmp_tensor.data<int8_t>();
PROCESS_FC_DATA()
} else {
const int16_t* int_data = tmp_tensor.data<int16_t>();
PROCESS_FC_DATA()
}
}
}
}
}
}
#undef PROCESS_CONV2D_DATA
#undef PROCESS_FC_DATA
}
} // namespace lite
} // namespace paddle
......@@ -78,6 +78,8 @@ class LITE_API LightPredictor {
void BuildRuntimeProgram(const cpp::ProgramDesc& prog);
void DequantizeWeight();
private:
std::shared_ptr<Scope> scope_;
std::unique_ptr<RuntimeProgram> program_;
......
......@@ -44,3 +44,4 @@ USE_MIR_PASS(memory_optimize_pass);
USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass);
USE_MIR_PASS(weight_quantization_preprocess_pass);
......@@ -35,6 +35,7 @@ lite_cc_library(mir_passes
demo_pass.cc
runtime_context_assign_pass.cc
memory_optimize_pass.cc
weight_quantization_preprocess_pass.cc
DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})
# lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
......
......@@ -100,14 +100,17 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
// conv
auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
->GetMutable<lite::Tensor>();
std::string conv_weight_name = matched.at("conv_weight")->arg()->name;
auto conv_weight_t =
scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
static_cast<size_t>(conv_weight_t->dims()[0]))
<< "The BN bias's size should be equal to the size of the first "
<< "dim size of the conv weights";
size_t weight_num = conv_weight_t->data_size();
bool enable_int8 = conv_op_desc->HasAttr("enable_int8");
bool is_weight_quantization = conv_op_desc->HasAttr("quantize_weight_bits");
// compute BN alpha and beta
Tensor alpha_tensor, beta_tensor;
......@@ -160,6 +163,16 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
}
}
conv_op_desc->SetAttr("weight_scale", weight_scale);
} else if (is_weight_quantization) {
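// For a weight-quantized conv, fold the BN alpha into the per-channel
// quantization scales instead of rewriting the integer weights.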
std::string scale_name = conv_weight_name + "_quant_scale";
if (conv_op_desc->HasAttr(scale_name)) {
auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name);
CHECK_EQ(scale.size(), alpha_tensor.numel());
for (size_t i = 0; i < scale.size(); i++) {
scale[i] *= alpha_data[i];
}
conv_op_desc->SetAttr(scale_name, scale);
}
} else {
// compute new conv_weight
auto conv_weight_d = conv_weight_t->mutable_data<float>();
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/weight_quantization_preprocess_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <vector>
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void WeightQuantizationPreprocessPass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> weight_quantized_op = {"conv2d", "depthwise_conv2d"};
for (auto& node : graph->StmtTopologicalOrder()) {
if (node->IsStmt() &&
std::find(weight_quantized_op.begin(),
weight_quantized_op.end(),
node->AsStmt().op_type()) != weight_quantized_op.end()) {
auto* scope = node->stmt()->op()->scope();
auto* op_desc = node->stmt()->mutable_op_info();
if (op_desc->HasAttr("quantize_weight_bits")) {
for (auto& input_name : op_desc->input_vars()) {
std::string scale_name = input_name + "_quant_scale";
if (op_desc->HasAttr(scale_name)) {
VLOG(5) << "op:" << op_desc->Type() << " input_name:" << input_name;
auto input_tensor =
scope->FindVar(input_name)->GetMutable<lite::Tensor>();
int weight_out_channel = static_cast<int>(input_tensor->dims()[0]);
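// The saved scale attribute holds a single value; broadcast it so that
// every output channel has its own entry in the scale list.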
auto input_scale = op_desc->GetAttr<std::vector<float>>(scale_name);
// scale length is equal to weight out channel
std::vector<float> scale_list(weight_out_channel, input_scale[0]);
op_desc->SetAttr(scale_name, scale_list);
}
}
}
}
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(weight_quantization_preprocess_pass,
paddle::lite::mir::WeightQuantizationPreprocessPass)
.BindTargets({TARGET(kAny)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include "lite/core/mir/pass.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* If the model is quantized by WeightQuantization in PostTrainingQuantization,
* the weights of the quantized ops (conv2d, depthwise_conv2d) are stored as
* integers and the quantization scale is saved as an attribute of those ops.
* WeightQuantizationPreprocessPass obtains the scale value, expands it into a
* per-output-channel list, and saves the list back into the quantized ops.
*/
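// For example (illustrative sizes): a conv2d whose weight has 32 output
// channels and carries the attribute <weight_name>_quant_scale = {s} ends up
// with a 32-element attribute {s, s, ..., s} after this pass.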
class WeightQuantizationPreprocessPass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -55,10 +55,11 @@ class Optimizer {
if (passes.empty()) {
std::vector<std::string> passes_local{
{"lite_quant_dequant_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
{"lite_quant_dequant_fuse_pass", //
"weight_quantization_preprocess_pass", //
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-bn-elemwise
// TODO(Superjomn) Refine the fusion related design to select fusion
// kernels for devices automatically.
"lite_conv_activation_fuse_pass", //
......
......@@ -45,6 +45,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
DO(FP16, float);
DO(FP32, float);
DO(INT8, int8_t);
DO(INT16, int16_t);
DO(INT32, int);
DO(INT64, int64_t);
#undef DO
......