Unverified commit 8e2d4d30, authored by baoachun and committed by GitHub

add mkldnn int8 pass [step3] (#41599)

* add mkldnn int8 pass [step3]

* Add test for compute_propagate_scales_mkldnn_pass

* update pass

* update api comment and python api
Co-authored-by: wozna <joanna.wozna@intel.com>
Parent c7623d72
......@@ -218,6 +218,7 @@ endif()
cc_test(test_scale_matmul_fuse_pass SRCS mkldnn/scale_matmul_fuse_pass_tester.cc DEPS scale_matmul_fuse_pass)
cc_test(test_mkldnn_placement_pass SRCS mkldnn/mkldnn_placement_pass_tester.cc DEPS mkldnn_placement_pass)
cc_test(test_mkldnn_inplace_pass SRCS mkldnn/mkldnn_inplace_pass_tester.cc DEPS mkldnn_inplace_pass)
cc_test(test_compute_propagate_scales_mkldnn_pass SRCS mkldnn/compute_propagate_scales_mkldnn_pass_tester.cc DEPS compute_propagate_scales_mkldnn_pass naive_executor)
cc_test(test_cpu_quantize_placement_pass SRCS mkldnn/cpu_quantize_placement_pass_tester.cc DEPS cpu_quantize_placement_pass)
cc_test(test_cpu_quantize_pass SRCS mkldnn/cpu_quantize_pass_tester.cc DEPS cpu_quantize_pass naive_executor)
cc_test(test_cpu_quantize_squash_pass SRCS mkldnn/cpu_quantize_squash_pass_tester.cc DEPS cpu_quantize_squash_pass naive_executor)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/compute_propagate_scales_mkldnn_pass.h"
#include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace ir {
const std::array<float, 10> positive_and_negative_values = {
-0.0482659, -0.0102493, -0.00794221, -0.00387115, -0.00674586,
-0.0495346, 0.0629528, -0.00531285, -0.0230353, 0.0269089};
const std::vector<std::vector<float>> wx = {
{0.04347931, -0.5643393, 0.7551297, 0.26713502, 0.8055306, 0.91144973},
{0.01707571, 0.12741385, 0.15419468, 0.66127586, 0.46821925, 0.9665961},
{0.40393898, 0.884427, -0.5853097, 0.5840954, 0.9170512, 0.98245513}};
const std::vector<std::vector<float>> wh = {
{0.42484227, -0.9025513, 0.17087583, 0.8403284, 0.03325734, 0.92331886},
{0.32630175, 0.41691914, 0.99848574, 0.3504407, 0.06707559, 0.62239844}};
const std::vector<double> gru_scales = {2.35381475, 1.08304947, 1.32427582,
1.19001095, 1.00151656, 1.01785819};
const std::vector<double> lstm_scales = {2.35381475, 1.10797026, 1.00151656,
1.19001095, 1.09045166, 1.01785819};
static const std::initializer_list<std::string> conv_variable_names{
"conv_in", "filter", "bias", "conv_out"};
static const std::initializer_list<std::string> rnn_variable_names{
"x", "wx", "wh", "b", "h", "c"};
class ComputePropagateScalesMkldnnPassTest : public testing::Test {
public:
ComputePropagateScalesMkldnnPassTest() {
pass.reset(new ComputePropagateScalesMkldnnPass());
}
std::vector<float> GetScales(Tensor* tensor, int axis) const {
return pass->GetScales(tensor, axis);
}
void ComputeVarScales(ir::Graph* graph, Scope* scope,
const std::unordered_set<std::string> ops,
const std::string& weight_name, const int axis,
StringPairMap* var_quant_scales) const {
pass->ComputeVarScales(graph, scope, ops, weight_name, axis,
var_quant_scales);
}
void ComputeGruWeightScales(ir::Graph* graph, Scope* scope,
const std::string& wx_name,
const std::string& wh_name,
StringPairMap* var_quant_scales) const {
pass->ComputeGruWeightScales(graph, scope, wx_name, wh_name,
var_quant_scales);
}
void ComputeLstmWeightScales(ir::Graph* graph, Scope* scope,
std::string wx_name, std::string wh_name,
StringPairMap* var_quant_scales) const {
pass->ComputeLstmWeightScales(graph, scope, wx_name, wh_name,
var_quant_scales);
}
void InitTensorHolder(Scope* scope, const paddle::platform::Place& place,
const std::string& var_name) {
auto x = scope->Var(var_name);
auto tensor = x->GetMutable<LoDTensor>();
auto tensor_size = 1;
if (var_name == "filter") {
tensor_size = positive_and_negative_values.size();
} else if (var_name == "wx") {
tensor_size = wx.size();
} else if (var_name == "wh") {
tensor_size = wh.size();
}
tensor->mutable_data(place,
framework::TransToPhiDataType(proto::VarType::FP32),
tensor_size);
}
void PrepareGraph(ir::Graph* graph, const ProgramDesc& prog, Scope* scope,
const std::initializer_list<std::string>& variable_names) {
auto place = paddle::platform::CPUPlace();
NaiveExecutor exe{place};
exe.CreateVariables(prog, 0, true, scope);
for (auto& v : variable_names) {
InitTensorHolder(scope, place, v.c_str());
}
graph->SetNotOwned(kParamScopeAttr, scope);
}
void ComputeRnnWeightScalesTest(const std::string& type,
const std::initializer_list<std::string>& ops,
const framework::ProgramDesc& prog,
std::vector<double> scales) {
ir::Graph* graph(new ir::Graph(prog));
Scope scope;
PrepareGraph(graph, prog, &scope, rnn_variable_names);
std::string wx_name = "WeightX";
std::string wh_name = "WeightH";
std::string wx_var_names = "wx";
std::string wh_var_names = "wh";
StringPairMap var_quant_scales;
auto* wx_var = scope.FindVar(wx_var_names);
auto* wx_tensor = wx_var->GetMutable<LoDTensor>();
wx_tensor->Resize(phi::make_dim(wx.size(), wx[0].size()));
for (size_t i = 0; i < wx.size(); i++)
std::copy(begin(wx[i]), end(wx[i]),
wx_tensor->mutable_data<float>(platform::CPUPlace()) +
i * wx[0].size());
auto* wh_var = scope.FindVar(wh_var_names);
auto* wh_tensor = wh_var->GetMutable<LoDTensor>();
wh_tensor->Resize(phi::make_dim(wh.size(), wh[0].size()));
for (size_t i = 0; i < wh.size(); i++)
std::copy(begin(wh[i]), end(wh[i]),
wh_tensor->mutable_data<float>(platform::CPUPlace()) +
i * wh[0].size());
if (type == "gru") {
ComputeGruWeightScales(graph, &scope, wx_name, wh_name,
&var_quant_scales);
} else {
ComputeLstmWeightScales(graph, &scope, wx_name, wh_name,
&var_quant_scales);
}
bool is_unsigned;
framework::Tensor wx_result_tensor;
std::tie(is_unsigned, wx_result_tensor) = var_quant_scales[wx_var_names];
ASSERT_EQ(is_unsigned, false);
ASSERT_EQ(wx_result_tensor.numel(), static_cast<int64_t>(scales.size()));
for (int64_t i = 0; i < wx_result_tensor.numel(); i++) {
ASSERT_FLOAT_EQ(wx_result_tensor.data<float>()[i], scales[i]);
}
}
private:
std::unique_ptr<ComputePropagateScalesMkldnnPass> pass;
};
void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs) {
auto* op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetAttr("use_mkldnn", true);
op->SetAttr("name", name);
if (type == "conv2d") {
op->SetInput("Input", {inputs[0]});
if (inputs.size() > 1) op->SetInput("Filter", {inputs[1]});
if (inputs.size() > 2) op->SetInput("Bias", {inputs[2]});
op->SetOutput("Output", {outputs[0]});
} else if (type == "fusion_gru" || type == "fusion_lstm") {
op->SetInput("X", {inputs[0]});
op->SetInput("WeightX", {inputs[1]});
op->SetInput("WeightH", {inputs[2]});
op->SetOutput("Hidden", {outputs[0]});
if (type == "fusion_lstm") op->SetOutput("Cell", {outputs[1]});
}
}
ProgramDesc BuildConv2dProgramDesc() {
ProgramDesc prog;
for (auto& v : conv_variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "conv2d", "Conv2d", {"conv_in", "filter", "bias"}, {"conv_out"});
return prog;
}
ProgramDesc BuildFusionGruProgramDesc() {
ProgramDesc prog;
for (auto& v : rnn_variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "fusion_gru", "Fusion_gru", {"x", "wx", "wh"}, {"h"});
return prog;
}
ProgramDesc BuildFusionLstmProgramDesc() {
ProgramDesc prog;
for (auto& v : rnn_variable_names) {
prog.MutableBlock(0)->Var(v);
}
SetOp(&prog, "fusion_lstm", "Fusion_lstm", {"x", "wx", "wh"}, {"h", "c"});
return prog;
}
TEST_F(ComputePropagateScalesMkldnnPassTest, get_scales_function) {
const auto& values = positive_and_negative_values;
float max_val = *std::max_element(values.begin(), values.end());
framework::Tensor var_tensor;
var_tensor.Resize(phi::make_dim(values.size(), 1));
std::copy(begin(values), end(values),
var_tensor.mutable_data<float>(platform::CPUPlace()));
std::vector<float> results = GetScales(&var_tensor, 0);
ASSERT_EQ(results.size(), std::size_t(1));
ASSERT_EQ(results[0], (1.f / max_val));
}
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_var_scales) {
auto prog = BuildConv2dProgramDesc();
const auto& values = positive_and_negative_values;
ir::Graph* graph(new ir::Graph(prog));
Scope scope;
PrepareGraph(graph, prog, &scope, conv_variable_names);
std::initializer_list<std::string> ops = {"conv2d", "depthwise_conv2d"};
std::string weight_name = "Filter";
std::string weight_var_name = "filter";
auto axis = 1;
StringPairMap var_quant_scales;
auto* var = scope.FindVar(weight_var_name);
auto* weight_tensor = var->GetMutable<LoDTensor>();
weight_tensor->Resize(phi::make_dim(1, values.size()));
std::copy(begin(values), end(values),
weight_tensor->mutable_data<float>(platform::CPUPlace()));
auto max_val = *std::max_element(values.begin(), values.end());
ComputeVarScales(graph, &scope, ops, weight_name, axis, &var_quant_scales);
bool is_unsigned;
framework::Tensor result_tensor;
std::tie(is_unsigned, result_tensor) = var_quant_scales[weight_var_name];
ASSERT_EQ(is_unsigned, false);
ASSERT_EQ(result_tensor.numel(), 1);
ASSERT_FLOAT_EQ(result_tensor.data<float>()[0], (1.0 / max_val));
}
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_gru_weight_scales) {
ComputeRnnWeightScalesTest("gru", {"fusion_gru", "multi_gru"},
BuildFusionGruProgramDesc(), gru_scales);
}
TEST_F(ComputePropagateScalesMkldnnPassTest, compute_lstm_weight_scales) {
ComputeRnnWeightScalesTest("lstm", {"fusion_lstm"},
BuildFusionLstmProgramDesc(), lstm_scales);
}
} // namespace ir
} // namespace framework
} // namespace paddle
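A note on the expected value in get_scales_function above: the largest magnitude in positive_and_negative_values is the positive entry 0.0629528, so 1.f / max_val coincides with the usual reciprocal-of-the-maximum-absolute-value scale. A standalone sketch of that arithmetic (not part of the test file, just the expectation spelled out):
#include <algorithm>
#include <array>
#include <cmath>
#include <cstdio>

// Sketch only: reproduces the expectation checked in get_scales_function.
// For this data the positive maximum is also the absolute maximum, so
// 1.f / *std::max_element(...) and 1.f / max|w| give the same number.
int main() {
  const std::array<float, 10> values = {
      -0.0482659, -0.0102493, -0.00794221, -0.00387115, -0.00674586,
      -0.0495346, 0.0629528,  -0.00531285, -0.0230353,  0.0269089};
  float max_abs = 0.f;
  for (float v : values) max_abs = std::max(max_abs, std::abs(v));
  const float expected_scale = 1.f / max_abs;  // == 1.f / 0.0629528f
  std::printf("expected scale: %f\n", expected_scale);
  return 0;
}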
......@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include <sstream>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/ir/mkldnn/cpu_quantize_pass.h"
#include "paddle/fluid/framework/ir/mkldnn/mkldnn_pass_util.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/string/pretty_log.h"
......@@ -226,12 +226,21 @@ void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
 bool CPUQuantizePass::AreScalesPresentForVarNames(
     std::vector<std::string> names) const {
-  auto& scales = Get<VarQuantScale>("quant_var_scales");
   bool present = true;
-  for (auto name : names) {
-    if (scales.find(name) == scales.end()) {
-      present = false;
-      LogScaleIsMissingForVarName(name);
+  if (var_quant_scales_->empty()) {
+    auto& scales = Get<VarQuantScale>("quant_var_scales");
+    for (auto name : names) {
+      if (scales.find(name) == scales.end()) {
+        present = false;
+        LogScaleIsMissingForVarName(name);
+      }
+    }
+  } else {
+    for (auto name : names) {
+      if (var_quant_scales_->find(name) == var_quant_scales_->end()) {
+        present = false;
+        LogScaleIsMissingForVarName(name);
+      }
+    }
   }
   return present;
......@@ -239,12 +248,21 @@ bool CPUQuantizePass::AreScalesPresentForVarNames(
 bool CPUQuantizePass::AreScalesPresentForNodes(
     std::initializer_list<Node*> nodes) const {
-  auto& scales = Get<VarQuantScale>("quant_var_scales");
   bool present = true;
-  for (auto node : nodes) {
-    if (scales.count(node->Name()) == 0) {
-      present = false;
-      LogScaleIsMissingForVarNode(node);
+  if (var_quant_scales_->empty()) {
+    auto& scales = Get<VarQuantScale>("quant_var_scales");
+    for (auto node : nodes) {
+      if (scales.count(node->Name()) == 0) {
+        present = false;
+        LogScaleIsMissingForVarNode(node);
+      }
+    }
+  } else {
+    for (auto node : nodes) {
+      if (var_quant_scales_->count(node->Name()) == 0) {
+        present = false;
+        LogScaleIsMissingForVarNode(node);
+      }
+    }
   }
   return present;
......@@ -252,8 +270,11 @@ bool CPUQuantizePass::AreScalesPresentForNodes(
 std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataByName(
     const std::string& name) const {
-  auto& scales = Get<VarQuantScale>("quant_var_scales");
-  return scales.at(name);
+  if (var_quant_scales_->empty()) {
+    auto& scales = Get<VarQuantScale>("quant_var_scales");
+    return scales.at(name);
+  }
+  return var_quant_scales_->at(name);
 }
std::pair<bool, LoDTensor> CPUQuantizePass::GetScaleDataForNode(
......@@ -290,6 +311,23 @@ bool CPUQuantizePass::IsOpQuantized(const Node* node) const {
});
}
void CPUQuantizePass::GetQuantInfo(Graph* graph) const {
std::unordered_map<std::string, std::vector<float>> info_map{};
GetInfoFromTheFirstOp(graph, "has_quant_info", "var_quant_scales", &info_map);
for (auto iter = info_map.begin(); iter != info_map.end(); iter++) {
LoDTensor tensor;
const int size = static_cast<int>(iter->second.size());
auto* data = tensor.mutable_data<double>({size}, platform::CPUPlace());
for (int i = 0; i < size; i++) {
data[i] = static_cast<double>(iter->second[i]);
}
auto pair = std::make_pair(false, tensor);
var_quant_scales_->insert(std::make_pair(iter->first, pair));
}
}
void CPUQuantizePass::QuantizeConv(Graph* graph,
bool with_residual_data) const {
GraphPatternDetector gpd;
......@@ -1138,6 +1176,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(param_scope(), platform::errors::InvalidArgument(
"Scope cannot be nullptr."));
GetQuantInfo(graph);
QuantizeConv(graph, false /* with_residual_data */);
QuantizeConv(graph, true /* with_residual_data */);
QuantizePool(graph);
......
......@@ -95,6 +95,12 @@ class CPUQuantizePass : public FusePassBase {
bool IsOpQuantized(const Node* node) const;
const std::string name_scope_{"quantize"};
private:
VarQuantScale string_pair_map = {};
VarQuantScale* const var_quant_scales_ = &string_pair_map;
void GetQuantInfo(Graph* graph) const;
};
} // namespace ir
......
......@@ -200,10 +200,8 @@ void QuantDequantMkldnnPass::CollectFakeQuantizeOps(
   for (auto* node_input : op_node->inputs) {
     if (node_input->Name() == x_var_name) {
       fake_quant_in = node_input;
-      break;
     } else if (node_input->Name() == in_scale_name) {
       fake_quant_in_scale = node_input;
-      break;
     }
   }
......@@ -212,10 +210,8 @@ void QuantDequantMkldnnPass::CollectFakeQuantizeOps(
   for (auto* node_output : op_node->outputs) {
     if (node_output->Name() == out_var_name) {
       fake_quant_out = node_output;
-      break;
     } else if (node_output->Name() == out_scale_name) {
       fake_quant_out_scale = node_output;
-      break;
     }
   }
......
......@@ -182,6 +182,8 @@ struct Argument {
// A set of op types to enable their bfloat16 kernels
DECL_ARGUMENT_FIELD(bfloat16_enabled_op_types, Bfloat16EnabledOpTypes,
std::unordered_set<std::string>);
DECL_ARGUMENT_FIELD(use_mkldnn_int8, UseMkldnnInt8, bool);
#endif
// Passed from config.
......
......@@ -107,6 +107,10 @@ void IRPassManager::CreatePasses(Argument *argument,
"quantize_excluded_op_ids",
new std::unordered_set<int>(argument->quantize_excluded_op_ids()));
} else if (pass_name == "cpu_quantize_pass") {
if (argument->quantize_enabled_op_types().count("conv2d") ||
argument->quantize_enabled_op_types().count("depthwise_conv2d")) {
pass->Set("data_layout", new std::string("NHWC"));
}
pass->Set("quant_var_scales",
new VarQuantScale(argument->quant_var_scales()));
} else if (pass_name == "cpu_bfloat16_placement_pass") {
......
......@@ -261,6 +261,9 @@ AnalysisConfig::AnalysisConfig(const AnalysisConfig &other) {
CP_MEMBER(use_mkldnn_bfloat16_);
CP_MEMBER(bfloat16_enabled_op_types_);
// Quantization related.
CP_MEMBER(use_mkldnn_int8_);
CP_MEMBER(quantize_enabled_op_types_);
CP_MEMBER(quantize_excluded_op_ids_);
CP_MEMBER(use_mkldnn_quantizer_);
CP_MEMBER(mkldnn_quantizer_config_);
CP_MEMBER(min_input_shape_);
......@@ -435,6 +438,35 @@ void AnalysisConfig::EnableMkldnnBfloat16() {
Update();
}
void AnalysisConfig::EnableMkldnnInt8(
const std::unordered_set<std::string> &op_list) {
#ifdef PADDLE_WITH_MKLDNN
use_mkldnn_int8_ = true;
use_fc_padding_ = false;
if (!op_list.empty()) {
for (auto &type : op_list) {
if (!quantize_enabled_op_types_.count(type)) {
LOG(ERROR) << "There are unsupported operators in the configured "
"quantization operator list. The unsupported operator "
"is: "
<< type;
use_mkldnn_int8_ = false;
break;
}
}
if (use_mkldnn_int8_) {
quantize_enabled_op_types_.clear();
quantize_enabled_op_types_.insert(op_list.begin(), op_list.end());
}
}
#else
LOG(ERROR) << "Please compile with MKLDNN first to use MkldnnInt8";
use_mkldnn_int8_ = false;
#endif
Update();
}
MkldnnQuantizerConfig *AnalysisConfig::mkldnn_quantizer_config() const {
PADDLE_ENFORCE_NOT_NULL(mkldnn_quantizer_config_,
platform::errors::PreconditionNotMet(
......@@ -632,6 +664,20 @@ void AnalysisConfig::Update() {
#endif
}
if (use_mkldnn_int8_) {
#ifdef PADDLE_WITH_MKLDNN
if (!enable_ir_optim_) {
LOG(ERROR) << "EnableMkldnnInt8() only works when IR optimization "
"is enabled.";
} else if (!use_mkldnn_) {
LOG(ERROR) << "EnableMkldnnInt8() only works when MKLDNN "
"is enabled.";
} else {
pass_builder()->EnableMkldnnInt8();
}
#endif
}
#ifdef PADDLE_WITH_MKLDNN
// Do not optimize when mkldnn is on
if (enable_memory_optim_ && !use_mkldnn_) {
......@@ -731,6 +777,9 @@ std::string AnalysisConfig::SerializeInfoCache() {
ss << use_mkldnn_quantizer_;
ss << use_mkldnn_bfloat16_;
for (auto &item : bfloat16_enabled_op_types_) ss << item;
ss << use_mkldnn_int8_;
for (auto &item : quantize_enabled_op_types_) ss << item;
for (auto &item : quantize_excluded_op_ids_) ss << item;
ss << ";";
ss << model_from_memory_;
......
......@@ -949,6 +949,13 @@ void AnalysisPredictor::PrepareArgument() {
LOG(INFO) << "Bfloat16 is enabled";
argument_.SetBfloat16EnabledOpTypes(config_.bfloat16_enabled_op_types_);
}
if (config_.use_mkldnn_int8_) {
LOG(INFO) << "Int8 is enabled";
argument_.SetQuantizeEnabledOpTypes(config_.quantize_enabled_op_types_);
argument_.SetQuantizeExcludedOpIds(config_.quantize_excluded_op_ids_);
argument_.SetQuantVarScales({});
}
#endif
auto passes = config_.pass_builder()->AllPasses();
......
......@@ -712,6 +712,20 @@ struct PD_INFER_DECL AnalysisConfig {
///
void EnableMkldnnQuantizer();
///
/// \brief Turn on MKLDNN int8.
///
/// \param op_list The operator type list.
///
void EnableMkldnnInt8(const std::unordered_set<std::string>& op_list = {});
///
/// \brief A boolean state telling whether to use the MKLDNN Int8.
///
/// \return bool Whether to use the MKLDNN Int8.
///
bool mkldnn_int8_enabled() const { return use_mkldnn_int8_; }
///
/// \brief Turn on MKLDNN bfloat16.
///
......@@ -981,6 +995,26 @@ struct PD_INFER_DECL AnalysisConfig {
std::shared_ptr<MkldnnQuantizerConfig> mkldnn_quantizer_config_;
bool use_mkldnn_bfloat16_{false};
std::unordered_set<std::string> bfloat16_enabled_op_types_;
bool use_mkldnn_int8_{false};
std::unordered_set<int> quantize_excluded_op_ids_{};
std::unordered_set<std::string> quantize_enabled_op_types_{
"concat",
"conv2d",
"depthwise_conv2d",
"elementwise_add",
"elementwise_mul",
"fc",
"matmul",
"nearest_interp",
"nearest_interp_v2",
"pool2d",
"prior_box",
"reshape2",
"transpose2",
"fusion_gru",
"fusion_lstm",
"multi_gru",
"slice"};
// ipu related.
bool use_ipu_{false};
......
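The new switch is meant to be combined with the existing MKLDNN options; per AnalysisConfig::Update() above, the INT8 pass pipeline is skipped with an error unless IR optimization and MKLDNN are both enabled. A minimal usage sketch follows; the model directory and the restricted op list are illustrative assumptions, not taken from this patch:
#include "paddle/fluid/inference/api/paddle_analysis_config.h"

// Sketch only: wires the new EnableMkldnnInt8() option into a CPU config.
void ConfigureMkldnnInt8(paddle::AnalysisConfig* cfg) {
  cfg->SetModel("./MobileNetV1_quant/float");  // hypothetical model directory
  cfg->DisableGpu();
  cfg->SwitchIrOptim(true);  // without IR optimization Update() logs an error and skips INT8
  cfg->EnableMKLDNN();       // likewise, MKLDNN itself must be enabled (see Update())
  // Restrict quantization to convolutions; passing an empty set keeps the
  // default quantize_enabled_op_types_ list declared above.
  cfg->EnableMkldnnInt8({"conv2d", "depthwise_conv2d"});
}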
......@@ -220,6 +220,10 @@ void GpuPassStrategy::EnableMkldnnBfloat16() {
LOG(ERROR) << "GPU not support MKL-DNN bfloat16";
}
void GpuPassStrategy::EnableMkldnnInt8() {
LOG(ERROR) << "GPU not support MKL-DNN int8";
}
CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
// NOTE the large fusions should be located in the front, so that they will
// not be damaged by smaller ones.
......@@ -339,6 +343,75 @@ void CpuPassStrategy::EnableMkldnnBfloat16() {
#endif
}
void CpuPassStrategy::EnableMkldnnInt8() {
#ifdef PADDLE_WITH_MKLDNN
if (!use_mkldnn_int8_) {
passes_.clear();
passes_.push_back("quant_dequant_mkldnn_pass");
passes_.push_back("layer_norm_fuse_pass");
passes_.push_back("attention_lstm_fuse_pass");
passes_.push_back("seqconv_eltadd_relu_fuse_pass");
passes_.push_back("fc_lstm_fuse_pass");
passes_.push_back("mul_lstm_fuse_pass");
passes_.push_back("fc_gru_fuse_pass");
passes_.push_back("mul_gru_fuse_pass");
passes_.push_back("multi_gru_fuse_pass");
passes_.push_back("multi_gru_seq_fuse_pass");
passes_.push_back("seq_concat_fc_fuse_pass");
passes_.push_back("gpu_cpu_squeeze2_matmul_fuse_pass");
passes_.push_back("gpu_cpu_reshape2_matmul_fuse_pass");
passes_.push_back("gpu_cpu_flatten2_matmul_fuse_pass");
passes_.push_back("matmul_v2_scale_fuse_pass");
passes_.push_back("squared_mat_sub_fuse_pass");
passes_.push_back("is_test_pass");
passes_.push_back("gpu_cpu_map_matmul_v2_to_mul_pass");
passes_.push_back("gpu_cpu_map_matmul_v2_to_matmul_pass");
passes_.push_back("matmul_scale_fuse_pass");
passes_.push_back("gpu_cpu_map_matmul_to_mul_pass");
passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("mkldnn_placement_pass");
passes_.push_back("depthwise_conv_mkldnn_pass");
passes_.push_back("conv_bn_fuse_pass");
passes_.push_back("conv_eltwiseadd_bn_fuse_pass");
passes_.push_back("conv_transpose_bn_fuse_pass");
passes_.push_back("conv_transpose_eltwiseadd_bn_fuse_pass");
passes_.push_back("conv_bias_mkldnn_fuse_pass");
passes_.push_back("conv_transpose_bias_mkldnn_fuse_pass");
passes_.push_back("conv_elementwise_add_mkldnn_fuse_pass");
passes_.push_back("conv_concat_relu_mkldnn_fuse_pass");
passes_.push_back("conv_relu_mkldnn_fuse_pass");
passes_.push_back("conv_leaky_relu_mkldnn_fuse_pass");
passes_.push_back("conv_relu6_mkldnn_fuse_pass");
passes_.push_back("conv_swish_mkldnn_fuse_pass");
passes_.push_back("conv_hard_swish_mkldnn_fuse_pass");
passes_.push_back("conv_mish_mkldnn_fuse_pass");
passes_.push_back("conv_hard_sigmoid_mkldnn_fuse_pass");
passes_.push_back("conv_gelu_mkldnn_fuse_pass");
passes_.push_back("fc_fuse_pass");
passes_.push_back("repeated_fc_relu_fuse_pass");
passes_.push_back("fc_mkldnn_pass");
passes_.push_back("fc_act_mkldnn_fuse_pass");
passes_.push_back("matmul_transpose_reshape_fuse_pass");
passes_.push_back("matmul_v2_transpose_reshape_fuse_pass");
passes_.push_back("batch_norm_act_fuse_pass");
passes_.push_back("softplus_activation_mkldnn_fuse_pass");
passes_.push_back("compute_propagate_scales_mkldnn_pass");
passes_.push_back("scale_matmul_fuse_pass");
passes_.push_back("reshape_transpose_matmul_mkldnn_fuse_pass");
passes_.push_back("reshape_transpose_matmul_v2_mkldnn_fuse_pass");
passes_.push_back("cpu_quantize_placement_pass");
passes_.push_back("cpu_quantize_pass");
passes_.push_back("cpu_quantize_squash_pass");
passes_.push_back("simplify_with_basic_ops_pass");
passes_.push_back("mkldnn_inplace_pass");
passes_.push_back("runtime_context_cache_pass");
}
use_mkldnn_int8_ = true;
#else
use_mkldnn_int8_ = false;
#endif
}
IpuPassStrategy::IpuPassStrategy() : PassStrategy({}) {
passes_.assign({"inference_process_pass"});
}
......
......@@ -139,6 +139,9 @@ class PD_INFER_DECL PassStrategy : public PaddlePassBuilder {
/// \brief Enable MKLDNN bfloat16.
virtual void EnableMkldnnBfloat16() {}
/// \brief Enable MKLDNN int8.
virtual void EnableMkldnnInt8() {}
/// \brief Check if we are using gpu.
/// \return A bool variable implying whether we are in gpu mode.
bool use_gpu() const { return use_gpu_; }
......@@ -189,6 +192,7 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
use_mkldnn_ = other.use_mkldnn_;
use_mkldnn_quantizer_ = other.use_mkldnn_quantizer_;
use_mkldnn_bfloat16_ = other.use_mkldnn_bfloat16_;
use_mkldnn_int8_ = other.use_mkldnn_int8_;
}
/// \brief Default destructor.
virtual ~CpuPassStrategy() = default;
......@@ -205,10 +209,14 @@ class PD_INFER_DECL CpuPassStrategy : public PassStrategy {
/// \brief Enable MKLDNN bfloat16.
void EnableMkldnnBfloat16() override;
/// \brief Enable MKLDNN int8.
void EnableMkldnnInt8() override;
protected:
/// \cond Protected
bool use_mkldnn_quantizer_{false};
bool use_mkldnn_bfloat16_{false};
bool use_mkldnn_int8_{false};
/// \endcond
};
......@@ -243,6 +251,9 @@ class PD_INFER_DECL GpuPassStrategy : public PassStrategy {
/// \brief Not supported in GPU mode yet.
void EnableMkldnnBfloat16() override;
/// \brief Not supported in GPU mode yet.
void EnableMkldnnInt8() override;
/// \brief Default destructor.
virtual ~GpuPassStrategy() = default;
......
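For reference, the same switch can also be driven through the pass builder that AnalysisConfig::Update() uses internally; the sketch below only inspects the resulting pipeline, and the DeletePass() call is a hypothetical tweak rather than something this patch requires:
#include <iostream>

#include "paddle/fluid/inference/api/paddle_analysis_config.h"
#include "paddle/fluid/inference/api/paddle_pass_builder.h"

// Sketch only: dump the pass list installed by EnableMkldnnInt8().
void DumpInt8Passes(paddle::AnalysisConfig* cfg) {
  auto* builder = cfg->pass_builder();  // CpuPassStrategy for a CPU config
  builder->EnableMkldnnInt8();          // installs the INT8 pipeline listed above
  for (const auto& pass : builder->AllPasses()) {
    std::cout << pass << std::endl;     // quant_dequant_mkldnn_pass, ..., cpu_quantize_pass
  }
  builder->DeletePass("fc_mkldnn_pass");  // hypothetical: drop a single pass before prediction
}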
......@@ -168,7 +168,7 @@ function(inference_analysis_api_test_with_fake_data_run TARGET_NAME test_binary
--disable_mkldnn_fc=${disable_fc})
endfunction()
-function(inference_analysis_api_quant_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path)
+function(inference_analysis_api_quant_test_run TARGET_NAME test_binary fp32_model_dir int8_model_dir data_path enable_quant_int8)
inference_analysis_test_run(${TARGET_NAME}
COMMAND ${test_binary}
ARGS --fp32_model=${fp32_model_dir}
......@@ -176,6 +176,7 @@ function(inference_analysis_api_quant_test_run TARGET_NAME test_binary fp32_mode
--infer_data=${data_path}
--batch_size=50
--enable_int8=true
--enable_quant_int8=${enable_quant_int8}
--cpu_num_threads=${CPU_NUM_THREADS_ON_CI}
--with_accuracy_layer=false
--iterations=2)
......@@ -554,7 +555,20 @@ if(WITH_MKLDNN)
download_quant_data_without_verify(${QUANT2_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf.tar.gz")
endif(NOT LINUX)
download_quant_data_without_verify(${QUANT2_INT8_MobileNetV1_MODEL_DIR} "MobileNet_qat_perf_int8.tar.gz")
-  inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH})
+  inference_analysis_api_quant_test_run(test_analyzer_quant_performance_benchmark ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_INT8_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf_int8 ${IMAGENET_DATA_PATH} false)
# Quant2 MobileNetV1
inference_analysis_api_quant_test_run(test_analyzer_quant2_mobilenetv1_mkldnn ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${QUANT2_MobileNetV1_MODEL_DIR}/MobileNet_qat_perf/float ${IMAGENET_DATA_PATH} true)
# Quant2 ResNet50 with input/output scales in `fake_quantize_range_abs_max` operators and the `out_threshold` attributes,
# with weight scales in `fake_channel_wise_dequantize_max_abs` operators
set(QUANT2_RESNET50_CHANNELWISE_MODEL_DIR "${QUANT_DATA_DIR}/ResNet50_quant2_channelwise")
set(QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE "ResNet50_qat_channelwise.tar.gz")
if(NOT LINUX)
download_quant_data_without_verify(${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR} ${QUANT2_RESNET50_CHANNELWISE_MODEL_ARCHIVE})
endif(NOT LINUX)
set(QUANT2_RESNET50_MODEL ${QUANT2_RESNET50_CHANNELWISE_MODEL_DIR}/ResNet50_qat_channelwise)
inference_analysis_api_quant_test_run(test_analyzer_quant2_resnet50_channelwise_mkldnn ${QUANT_IMG_CLASS_TEST_APP} ${QUANT2_RESNET50_MODEL} ${QUANT2_RESNET50_MODEL} ${IMAGENET_DATA_PATH} true)
### Other tests
......@@ -774,6 +788,8 @@ if(WITH_MKLDNN)
set_tests_properties(test_analyzer_int8_mobilenetv2 PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_int8_mobilenetv1 PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_int8_mobilenetv3_large PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_quant2_mobilenetv1_mkldnn PROPERTIES TIMEOUT 120)
set_tests_properties(test_analyzer_quant2_resnet50_channelwise_mkldnn PROPERTIES TIMEOUT 120)
endif()
set_tests_properties(lite_resnet50_test PROPERTIES TIMEOUT 120)
......
......@@ -26,8 +26,7 @@ namespace analysis {
 void SetConfig(AnalysisConfig *cfg, std::string model_path) {
   cfg->SetModel(model_path);
   cfg->DisableGpu();
-  cfg->SwitchIrOptim(false);
-  cfg->SwitchSpecifyInputNames();
+  cfg->SwitchIrOptim(true);
   cfg->SetCpuMathLibraryNumThreads(FLAGS_cpu_num_threads);
   if (FLAGS_enable_mkldnn) cfg->EnableMKLDNN();
 }
......@@ -113,9 +112,11 @@ void SetInput(std::vector<std::vector<PaddleTensor>> *inputs,
TEST(Analyzer_quant_image_classification, quantization) {
AnalysisConfig fp32_cfg;
SetConfig(&fp32_cfg, FLAGS_fp32_model);
fp32_cfg.EnableMKLDNN();
AnalysisConfig int8_cfg;
SetConfig(&int8_cfg, FLAGS_int8_model);
if (FLAGS_enable_quant_int8) int8_cfg.EnableMkldnnInt8();
// read data from file and prepare batches with test data
std::vector<std::vector<PaddleTensor>> input_slots_all;
......
......@@ -53,6 +53,7 @@ DEFINE_bool(with_accuracy_layer, true,
DEFINE_bool(enable_fp32, true, "Enable FP32 type prediction");
DEFINE_bool(enable_bf16, false, "Enable BF16 type prediction");
DEFINE_bool(enable_int8, false, "Enable INT8 type prediction");
DEFINE_bool(enable_quant_int8, false, "Enable QUANT INT8 type prediction");
DEFINE_int32(warmup_batch_size, 100, "batch size for quantization warmup");
// setting iterations to 0 means processing the whole dataset
DEFINE_int32(iterations, 0, "number of batches to process");
......
......@@ -695,6 +695,10 @@ void BindAnalysisConfig(py::module *m) {
.def("set_mkldnn_cache_capacity", &AnalysisConfig::SetMkldnnCacheCapacity,
py::arg("capacity") = 0)
.def("set_bfloat16_op", &AnalysisConfig::SetBfloat16Op)
.def("enable_mkldnn_int8", &AnalysisConfig::EnableMkldnnInt8,
py::arg("mkldnn_int8_enabled_op_types") =
std::unordered_set<std::string>({}))
.def("mkldnn_int8_enabled", &AnalysisConfig::mkldnn_int8_enabled)
#endif
.def("set_mkldnn_op", &AnalysisConfig::SetMKLDNNOp)
.def("set_model_buffer", &AnalysisConfig::SetModelBuffer)
......