Unverified · Commit 1ee237c1 authored by lidanqing, committed by GitHub

add lstm qat models scales (#35382)

Parent 4e233712
@@ -17,7 +17,7 @@
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
class Scope;
@@ -335,6 +335,9 @@ void FCGRUFusePass::ApplyImpl(ir::Graph* graph) const {
      graph, name_scope_, param_scope(), true /*with_fc_bias*/);
  AddStatis(fusion_count);
string::PrettyLogDetail("--- fused %d pairs of fc gru patterns",
fusion_count);
}
}  // namespace ir
......
@@ -16,6 +16,7 @@
#include <string>
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
namespace framework {
@@ -348,6 +349,9 @@ void FCLstmFusePass::ApplyImpl(ir::Graph* graph) const {
  BuildFusion(graph, name_scope_, param_scope(), true /*with_fc_bias*/);
  AddStatis(fusion_count);
string::PrettyLogDetail("--- fused %d pairs of fc lstm patterns",
fusion_count);
}
}  // namespace ir
......
@@ -97,6 +97,19 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
    op->SetInput("WeightX", {inputs[2]});
    op->SetInput("WeightH", {inputs[3]});
    op->SetOutput("Hidden", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("Scale_data", 1.0f);
op->SetAttr("Shift_data", 0.0f);
op->SetAttr("Weight_scale", std::vector<float>{1.0f});
} else if (type == "fusion_lstm") {
op->SetInput("X", {inputs[0]});
op->SetInput("Bias", {inputs[1]});
op->SetInput("WeightX", {inputs[2]});
op->SetInput("WeightH", {inputs[3]});
op->SetOutput("Hidden", {outputs[0]});
op->SetOutput("Cell", {outputs[1]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type); op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("Scale_data", 1.0f); op->SetAttr("Scale_data", 1.0f);
op->SetAttr("Shift_data", 0.0f); op->SetAttr("Shift_data", 0.0f);
...@@ -418,6 +431,25 @@ ProgramDesc BuildProgramDescFusionGru() { ...@@ -418,6 +431,25 @@ ProgramDesc BuildProgramDescFusionGru() {
return prog; return prog;
} }
static const std::initializer_list<std::string> variable_names_fusion_lstm = {
"x", "wx", "wh", "b", "h", "c"};
// (x, wx, wh, b)->Fusion_lstm_1->(h, c)
ProgramDesc BuildProgramDescFusionLSTM() {
ProgramDesc prog;
for (auto& v : variable_names_fusion_lstm) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v.find("wx") == 0 || v.find("wh") || v.find("b")) {
var->SetPersistable(true);
}
}
SetOp(&prog, "fusion_lstm", "Fusion_lstm_1", {"x", "wx", "wh", "b"},
{"h", "c"}, true, "int8");
return prog;
}
void MainTestFusionGru(const ProgramDesc& prog, int gru_count, int quant_count,
                       int dequant_count, int added_nodes_count, float scale,
                       float shift) {
@@ -470,6 +502,59 @@ TEST(CpuQuantizePass, fusion_gru) {
                     dequant_count, added_nodes_count, 2. * 127, 128.);
}
void MainTestFusionLSTM(const ProgramDesc& prog, int expect_lstm_count,
int quant_count, int dequant_count,
int added_nodes_count, float scale, float shift) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names_fusion_lstm, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int lstm_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "fusion_lstm") {
lstm_nodes_count++;
auto op_name = BOOST_GET_CONST(std::string, op->GetAttr("name"));
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_data")), scale)
<< "Scale_data for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Shift_data")), shift)
<< "Shift_data for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(std::vector<float>,
op->GetAttr("Scale_weights"))[0],
scale)
<< "Scale_weights for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(bool, op->GetAttr("force_fp32_output")), true)
<< "force_fp32_output for node '" + op_name + "'.";
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(lstm_nodes_count, expect_lstm_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, fusion_lstm) {
// (x, wx, wh, b)->Fusion_lstm->(h, c)
int expect_lstm_count = 1;
int expect_quant_count = 1;
int dequant_count = 0;
// 1 Quant + 1 IN + 0 DeQuant + 0 OUT
int added_nodes_count = 1 + 1 + 0 + 0;
MainTestFusionLSTM(BuildProgramDescFusionLSTM(), expect_lstm_count,
expect_quant_count, dequant_count, added_nodes_count,
2. * 127, 128.);
}
const std::vector<std::string> churn_out_vars(ProgramDesc* prog,
                                              const std::string& prefix,
                                              int number) {
......
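A note on the expected counts in TEST(CpuQuantizePass, fusion_lstm) above: because force_fp32_output is true, only the fusion_lstm input is quantized, so the pass adds one quantize op plus one quantized input variable (added_nodes_count = 1 + 1 = 2) and no dequantize/output pair. The Scale_data = 2. * 127, Shift_data = 128. pair reads as shifted (unsigned) quantization; the snippet below is a minimal sketch of that mapping, assuming the usual affine rule q = round(scale * x + shift). It is illustrative only, not Paddle's implementation.

```python
import numpy as np

def quantize_shifted(x, scale=2.0 * 127, shift=128.0):
    # Assumed affine mapping for the scale/shift pair checked in the test:
    # signed floats are shifted into the unsigned 8-bit range [0, 255].
    q = np.round(scale * x + shift)
    return np.clip(q, 0, 255).astype(np.uint8)

x = np.array([-0.5, 0.0, 0.5], dtype=np.float32)
print(quantize_shifted(x))  # [  1 128 255] -- roughly the full u8 range
```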
@@ -71,6 +71,7 @@ class Quant2Int8MkldnnPass(object):
        self._relu_ops = ['relu', 'relu6']
        self._matmul_ops = ['matmul']
        self._gru_ops = ['fusion_gru', 'multi_gru']
self._lstm_ops = ['fusion_lstm']
        self._weight_thresholds = {}
        # Collect the input and output scales from fake quant models
        self._var_quant_scales = {}
@@ -535,10 +536,38 @@ class Quant2Int8MkldnnPass(object):
                self._var_quant_scales[wx_var_name] = (use_unsigned_int,
                                                       lod_tensor)
def _compute_single_lstm_weight_scales(wx_var_name, wh_var_name):
wx = np.array(self._load_param(self._scope, wx_var_name))
wh = np.array(self._load_param(self._scope, wh_var_name))
lstm_weights_scale = 1.0 / np.max(
np.abs(np.concatenate(
[wx[:, :], wh[:, :]], axis=0)), axis=0)
lstm_weights_scale = lstm_weights_scale.astype('float')
return self._convert_scale2tensor(lstm_weights_scale)
def _compute_lstm_weight_scales(wx_name, wh_name):
for op in graph.all_op_nodes():
if op.op().type() in self._lstm_ops:
assert len(op.input(wx_name)) == len(
op.input(wh_name)
), 'Mismatch in number of weights inputs ({} for WeightX vs. {} for WeightH).'.format(
len(op.input(wx_name)), len(op.input(wh_name)))
for i, wx_var_name in enumerate(op.input(wx_name)):
wh_var_name = op.input(wh_name)[i]
use_unsigned_int = False
lod_tensor = _compute_single_lstm_weight_scales(
wx_var_name, wh_var_name)
self._var_quant_scales[wx_var_name] = (use_unsigned_int,
lod_tensor)
        _compute_var_scales(self._conv_ops, "Filter", axis=1)
        _compute_var_scales(self._fc_ops, "W", axis=0)
        _compute_var_scales(self._gru_ops, "WeightH", axis=0)
_compute_var_scales(self._lstm_ops, "WeightH", axis=0)
_compute_gru_weight_scales("WeightX", "WeightH") _compute_gru_weight_scales("WeightX", "WeightH")
_compute_lstm_weight_scales("WeightX", "WeightH")
        return graph

    def _find_avg_pooling_ids(self, graph):
......
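To make the new _compute_single_lstm_weight_scales logic concrete, here is a minimal, self-contained sketch of the same per-column computation on dummy data. Shapes and names are illustrative only; the real pass loads the WeightX/WeightH tensors from the scope via self._load_param rather than generating them.

```python
import numpy as np

# Illustrative shapes: wx projects the input, wh the recurrent state; for an
# LSTM both share the same number of output columns (4 * hidden_size gates).
hidden_size = 2
wx = np.random.uniform(-1, 1, (5, 4 * hidden_size))            # WeightX
wh = np.random.uniform(-1, 1, (hidden_size, 4 * hidden_size))  # WeightH

# Same computation as _compute_single_lstm_weight_scales: stack both weight
# matrices along axis 0, then take one scale per column as the reciprocal of
# that column's max absolute value.
lstm_weights_scale = 1.0 / np.max(
    np.abs(np.concatenate([wx, wh], axis=0)), axis=0)
lstm_weights_scale = lstm_weights_scale.astype('float')

print(lstm_weights_scale.shape)  # (8,) -- one scale per gate output column
```

Taking the maximum over the concatenation means each output column gets a single scale that is valid for both the input-to-hidden and hidden-to-hidden weights feeding that gate channel.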
@@ -265,6 +265,12 @@ if(LINUX AND WITH_MKLDNN)
download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE} cf207f8076dcfb8b74d8b6bdddf9090c)
set(QUANT2_GRU_OPS_TO_QUANTIZE "multi_gru")
# Quant2 LSTM
set(QUANT2_LSTM_MODEL_ARCHIVE "lstm_quant.tar.gz")
set(QUANT2_LSTM_MODEL_DIR "${QUANT_INSTALL_DIR}/lstm_quant_test")
download_quant_model(${QUANT2_LSTM_MODEL_DIR} ${QUANT2_LSTM_MODEL_ARCHIVE} 40a693803b12ee9e251258f32559abcb)
set(QUANT2_LSTM_OPS_TO_QUANTIZE "fusion_lstm")
### Save FP32 model or INT8 model from Quant model
set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
@@ -276,6 +282,9 @@ if(LINUX AND WITH_MKLDNN)
set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8")
save_quant_nlp_model_test(save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE})
set(QUANT2_INT8_LSTM_SAVE_PATH "${QUANT_INSTALL_DIR}/lstm_quant2_int8")
save_quant_nlp_model_test(save_quant2_model_lstm ${QUANT2_LSTM_MODEL_DIR}/lstm_quant ${QUANT2_INT8_LSTM_SAVE_PATH} ${QUANT2_LSTM_OPS_TO_QUANTIZE})
# Convert Quant2 model to dot and pdf files
set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file")
convert_model2dot_test(convert_model2dot_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_DOT_SAVE_PATH} "Ernie_quant2_int8")
......