diff --git a/lite/api/light_api.cc b/lite/api/light_api.cc
index a0c4b7e5e375d9d004de63345ba5013ee6c252b9..1558e286178b461dc04c4366dc3adca81b2dd9de 100644
--- a/lite/api/light_api.cc
+++ b/lite/api/light_api.cc
@@ -41,6 +41,8 @@ void LightPredictor::Build(const std::string& model_dir,
     default:
       LOG(FATAL) << "Unknown model type";
   }
+
+  DequantizeWeight();
   BuildRuntimeProgram(cpp_program_desc_);
   PrepareFeedFetch();
 }
@@ -144,5 +146,69 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
   program_->set_exec_scope(program.exec_scope());
 }
 
+void LightPredictor::DequantizeWeight() {
+#define PROCESS_CONV2D_DATA()                                     \
+  for (int64_t i = 0; i < h; ++i) {                               \
+    for (int64_t j = 0; j < w; ++j) {                             \
+      fp_data[i * w + j] = scale_list[i] * int_data[i * w + j];   \
+    }                                                             \
+  }
+
+#define PROCESS_FC_DATA()                                         \
+  for (int i = 0; i < input_tensor->numel(); i++) {               \
+    *fp_data = scale_list[0] * (*int_data);                       \
+    ++fp_data;                                                    \
+    ++int_data;                                                   \
+  }
+
+  Tensor tmp_tensor;
+  CHECK(cpp_program_desc_.BlocksSize());
+  auto* main_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
+  for (size_t k = 0; k < main_block->OpsSize(); ++k) {
+    auto* op_desc = main_block->GetOp<cpp::OpDesc>(k);
+    if (op_desc->HasAttr("quantize_weight_bits")) {  // weight quantized op
+      auto input_names = op_desc->input_vars();
+      for (auto& input_name : input_names) {
+        std::string input_scale_name = input_name + "_quant_scale";
+        if (op_desc->HasAttr(input_scale_name)) {  // the input is quantized
+          auto input_tensor =
+              scope_->FindVar(input_name)->GetMutable<lite::Tensor>();
+          tmp_tensor.CopyDataFrom(*input_tensor);
+          auto scale_list =
+              op_desc->GetAttr<std::vector<float>>(input_scale_name);
+          int quantize_weight_bits =
+              op_desc->GetAttr<int>("quantize_weight_bits");
+          float* fp_data = input_tensor->mutable_data<float>();
+
+          std::string op_type = op_desc->Type();
+          if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
+            int64_t h = input_tensor->dims()[0];
+            int64_t w = input_tensor->numel() / h;
+            CHECK_EQ(scale_list.size(), h);
+            if (quantize_weight_bits == 8) {
+              const int8_t* int_data = tmp_tensor.data<int8_t>();
+              PROCESS_CONV2D_DATA()
+            } else {
+              const int16_t* int_data = tmp_tensor.data<int16_t>();
+              PROCESS_CONV2D_DATA()
+            }
+          } else if (op_type == "fc" || op_type == "mul") {
+            if (quantize_weight_bits == 8) {
+              const int8_t* int_data = tmp_tensor.data<int8_t>();
+              PROCESS_FC_DATA()
+            } else {
+              const int16_t* int_data = tmp_tensor.data<int16_t>();
+              PROCESS_FC_DATA()
+            }
+          }
+        }
+      }
+    }
+  }
+
+#undef PROCESS_CONV2D_DATA
+#undef PROCESS_FC_DATA
+}
+
 }  // namespace lite
 }  // namespace paddle
diff --git a/lite/api/light_api.h b/lite/api/light_api.h
index 3781bc4d674db5d2e8794edaf33f00627b9977bb..d1789a9c98333f6e927ba470717d9227729f2108 100644
--- a/lite/api/light_api.h
+++ b/lite/api/light_api.h
@@ -78,6 +78,8 @@ class LITE_API LightPredictor {
 
   void BuildRuntimeProgram(const cpp::ProgramDesc& prog);
 
+  void DequantizeWeight();
+
  private:
   std::shared_ptr<Scope> scope_;
   std::unique_ptr<RuntimeProgram> program_;
diff --git a/lite/api/paddle_use_passes.h b/lite/api/paddle_use_passes.h
index f5b7ea4d9f43b2a8802cd86da98bb8e95197d896..943760d30742b74a0fe9150e4c2d8c8bb5dbc52a 100644
--- a/lite/api/paddle_use_passes.h
+++ b/lite/api/paddle_use_passes.h
@@ -44,3 +44,4 @@ USE_MIR_PASS(memory_optimize_pass);
 USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
 USE_MIR_PASS(npu_subgraph_pass);
 USE_MIR_PASS(xpu_subgraph_pass);
+USE_MIR_PASS(weight_quantization_preprocess_pass);
diff --git a/lite/core/mir/CMakeLists.txt b/lite/core/mir/CMakeLists.txt
index a32e0295dbfc2b3e635472649b437b64f1e93145..379ef67f2996519d0c8007d8f191efbd2166a9e3 100644
--- a/lite/core/mir/CMakeLists.txt
+++ b/lite/core/mir/CMakeLists.txt
@@ -35,6 +35,7 @@ lite_cc_library(mir_passes
         demo_pass.cc
         runtime_context_assign_pass.cc
         memory_optimize_pass.cc
+        weight_quantization_preprocess_pass.cc
   DEPS mir_pass types context ${mir_fusers} ${mir_subgraphs})
 
 # lite_cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
diff --git a/lite/core/mir/fusion/conv_bn_fuser.cc b/lite/core/mir/fusion/conv_bn_fuser.cc
index ec07278eed1f259c45e225497f94d682b544c57c..0f5bb64e10dd61c3edf4ddd32569a2d365651cdf 100644
--- a/lite/core/mir/fusion/conv_bn_fuser.cc
+++ b/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -100,14 +100,17 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
   auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
 
   // conv
-  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
-                           ->GetMutable<lite::Tensor>();
+  std::string conv_weight_name = matched.at("conv_weight")->arg()->name;
+  auto conv_weight_t =
+      scope->FindVar(conv_weight_name)->GetMutable<lite::Tensor>();
   CHECK_EQ(static_cast<size_t>(bn_scale_t->data_size()),
            static_cast<size_t>(conv_weight_t->dims()[0]))
       << "The BN bias's size should be equal to the size of the first "
       << "dim size of the conv weights";
   size_t weight_num = conv_weight_t->data_size();
   bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
+  bool is_weight_quantization =
+      conv_op_desc->HasAttr("quantize_weight_bits") ? true : false;
 
   // comupte BN alpha and beta
   Tensor alpha_tensor, beta_tensor;
@@ -160,6 +163,16 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
       }
     }
     conv_op_desc->SetAttr("weight_scale", weight_scale);
+  } else if (is_weight_quantization) {
+    std::string scale_name = conv_weight_name + "_quant_scale";
+    if (conv_op_desc->HasAttr(scale_name)) {
+      auto scale = conv_op_desc->GetAttr<std::vector<float>>(scale_name);
+      CHECK_EQ(scale.size(), alpha_tensor.numel());
+      for (size_t i = 0; i < scale.size(); i++) {
+        scale[i] *= alpha_data[i];
+      }
+      conv_op_desc->SetAttr(scale_name, scale);
+    }
   } else {
     // compute new conv_weight
     auto conv_weight_d = conv_weight_t->mutable_data<float>();
diff --git a/lite/core/mir/weight_quantization_preprocess_pass.cc b/lite/core/mir/weight_quantization_preprocess_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..c7889a54903f2a1d194fb3eade0bd92670b36699
--- /dev/null
+++ b/lite/core/mir/weight_quantization_preprocess_pass.cc
@@ -0,0 +1,60 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "lite/core/mir/weight_quantization_preprocess_pass.h" +#include +#include +#include +#include "lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void WeightQuantizationPreprocessPass::Apply( + const std::unique_ptr& graph) { + std::vector weight_quantized_op = {"conv2d", "depthwise_conv2d"}; + for (auto& node : graph->StmtTopologicalOrder()) { + if (node->IsStmt() && + std::find(weight_quantized_op.begin(), + weight_quantized_op.end(), + node->AsStmt().op_type()) != weight_quantized_op.end()) { + auto* scope = node->stmt()->op()->scope(); + auto* op_desc = node->stmt()->mutable_op_info(); + if (op_desc->HasAttr("quantize_weight_bits")) { + for (auto& input_name : op_desc->input_vars()) { + std::string scale_name = input_name + "_quant_scale"; + if (op_desc->HasAttr(scale_name)) { + VLOG(5) << "op:" << op_desc->Type() << " input_name:" << input_name; + auto input_tensor = + scope->FindVar(input_name)->GetMutable(); + int weight_out_channel = static_cast(input_tensor->dims()[0]); + auto input_scale = op_desc->GetAttr>(scale_name); + // scale length is equal to weight out channel + std::vector scale_list(weight_out_channel, input_scale[0]); + op_desc->SetAttr(scale_name, scale_list); + } + } + } + } + } +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(weight_quantization_preprocess_pass, + paddle::lite::mir::WeightQuantizationPreprocessPass) + .BindTargets({TARGET(kAny)}); diff --git a/lite/core/mir/weight_quantization_preprocess_pass.h b/lite/core/mir/weight_quantization_preprocess_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..76a35c6b443c692ec08688abd4c10680be62b8af --- /dev/null +++ b/lite/core/mir/weight_quantization_preprocess_pass.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include "lite/core/mir/pass.h" +#include "lite/core/op_registry.h" +#include "lite/core/target_wrapper.h" + +namespace paddle { +namespace lite { +namespace mir { +/* + * If the model is quantized by WeightQuantization in PostTrainingQuantization, + * the data type of the weight in quantized ops (conv2d, depthwise_conv2d) is + * int, and the scale is save in the quantized ops. + * WeightQuantizationPreprocessPass obtains the scale value, expands the + * scale value to a list, and save the list in the quantized ops. 
+ */
+class WeightQuantizationPreprocessPass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/lite/core/optimizer.h b/lite/core/optimizer.h
index 00e9e07749901442f949fe885cdcfd358f822cba..ddd94484ac4bb8d96d5c55300c985d21b44f1843 100644
--- a/lite/core/optimizer.h
+++ b/lite/core/optimizer.h
@@ -55,10 +55,11 @@ class Optimizer {
 
     if (passes.empty()) {
       std::vector<std::string> passes_local{
-          {"lite_quant_dequant_fuse_pass",     //
-           "lite_conv_elementwise_fuse_pass",  // conv-elemwise-bn
-           "lite_conv_bn_fuse_pass",           //
-           "lite_conv_elementwise_fuse_pass",  // conv-bn-elemwise
+          {"lite_quant_dequant_fuse_pass",         //
+           "weight_quantization_preprocess_pass",  //
+           "lite_conv_elementwise_fuse_pass",      // conv-elemwise-bn
+           "lite_conv_bn_fuse_pass",               //
+           "lite_conv_elementwise_fuse_pass",      // conv-bn-elemwise
            // TODO(Superjomn) Refine the fusion related design to select fusion
            // kernels for devices automatically.
            "lite_conv_activation_fuse_pass",  //
diff --git a/lite/model_parser/model_parser.cc b/lite/model_parser/model_parser.cc
index ed3f45c598e74a0450454c15ad0cd9ad09266f8e..0dcb8e1eeab4b07d533a1bfc57cb8d9ca38b4d82 100644
--- a/lite/model_parser/model_parser.cc
+++ b/lite/model_parser/model_parser.cc
@@ -45,6 +45,7 @@ int SizeOfType(framework::proto::VarType::Type type) {
     DO(FP16, float);
     DO(FP32, float);
    DO(INT8, int8_t);
+    DO(INT16, int16_t);
    DO(INT32, int);
    DO(INT64, int64_t);
 #undef DO
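
Note on the arithmetic: the pass expands a single weight scale into one scale per output channel, and LightPredictor::DequantizeWeight() later maps the int8/int16 weights back to float as fp_data[i * w + j] = scale_list[i] * int_data[i * w + j]. The sketch below is a minimal standalone illustration of that logic in plain C++, outside the Lite API; the helper names ExpandScale and DequantizeConv2dWeight are invented for this example and are not part of the patch.

// Standalone illustration (not Lite code) of the per-channel weight
// dequantization used by the patch.
#include <cstdint>
#include <iostream>
#include <vector>

// Mirrors WeightQuantizationPreprocessPass: the saved attribute holds a single
// scale, while later stages expect one scale per output channel.
std::vector<float> ExpandScale(float single_scale, int out_channels) {
  return std::vector<float>(out_channels, single_scale);
}

// Mirrors PROCESS_CONV2D_DATA: row i of the flattened weight matrix belongs to
// output channel i, so every element in that row shares scale_list[i].
std::vector<float> DequantizeConv2dWeight(const std::vector<int8_t>& int_data,
                                          const std::vector<float>& scale_list,
                                          int64_t h,    // output channels
                                          int64_t w) {  // elements per channel
  std::vector<float> fp_data(h * w);
  for (int64_t i = 0; i < h; ++i) {
    for (int64_t j = 0; j < w; ++j) {
      fp_data[i * w + j] = scale_list[i] * int_data[i * w + j];
    }
  }
  return fp_data;
}

int main() {
  // Two output channels, three weights each.
  std::vector<int8_t> weights = {127, -64, 32, 10, -10, 5};
  auto scales = ExpandScale(0.02f, 2);
  auto fp = DequantizeConv2dWeight(weights, scales, 2, 3);
  for (float v : fp) std::cout << v << " ";  // 2.54 -1.28 0.64 0.2 -0.2 0.1
  std::cout << "\n";
  return 0;
}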