未验证 提交 966447e3 编写于 作者: W Wojciech Uss 提交者: GitHub

Added support for quantization of fusion_gru (#27518)

上级 0cd4907e
......@@ -20,7 +20,7 @@ SET(MKLDNN_SOURCE_DIR ${THIRD_PARTY_PATH}/mkldnn/src/extern_mkldnn)
SET(MKLDNN_INSTALL_DIR ${THIRD_PARTY_PATH}/install/mkldnn)
SET(MKLDNN_INC_DIR "${MKLDNN_INSTALL_DIR}/include" CACHE PATH "mkldnn include directory." FORCE)
SET(MKLDNN_REPOSITORY https://github.com/oneapi-src/oneDNN.git)
SET(MKLDNN_TAG 64a48f9565aa72f6359917b3406328075a409939)
SET(MKLDNN_TAG 361725600224f41b7347a1c6bee9b04d1e6c14d7)
# Introduce variables:
# * CMAKE_INSTALL_LIBDIR
......
......@@ -1882,9 +1882,9 @@ PDNode *patterns::MultipleQuantize::operator()() {
PDNode *patterns::QuantizePlacement::operator()(
const std::unordered_set<std::string> &quantize_enabled_op_types) {
std::unordered_set<std::string> supported_op_types =
std::unordered_set<std::string>({"concat", "conv2d", "elementwise_add",
"fc", "matmul", "pool2d", "prior_box",
"relu", "reshape2", "transpose2"});
std::unordered_set<std::string>(
{"concat", "conv2d", "elementwise_add", "fc", "matmul", "pool2d",
"prior_box", "relu", "reshape2", "transpose2", "fusion_gru"});
if (!quantize_enabled_op_types.empty()) {
supported_op_types = quantize_enabled_op_types;
}
......@@ -2280,6 +2280,23 @@ PDNode *patterns::MatmulTransposeReshapePattern::operator()() {
return reshape_out;
}
PDNode *patterns::FusionGru::operator()() {
auto op = pattern->NewNode(op_repr())->assert_is_op("fusion_gru");
auto x = pattern->NewNode(x_repr())->AsInput()->assert_is_op_input(
"fusion_gru", "X");
auto weight_h = pattern->NewNode(weight_h_repr())
->AsInput()
->assert_is_op_input("fusion_gru", "WeightH");
auto weight_x = pattern->NewNode(weight_x_repr())
->AsInput()
->assert_is_op_input("fusion_gru", "WeightX");
auto out = pattern->NewNode(out_repr())
->AsOutput()
->assert_is_op_output("fusion_gru", "Hidden");
op->LinksFrom({x, weight_h, weight_x}).LinksTo({out});
return out;
}
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -1312,6 +1312,21 @@ struct MatmulTransposeReshapePattern : public PatternBase {
PATTERN_DECL_NODE(reshape_out_xshape);
};
// fusion_gru op
// Forward pass for fusion_gru.
// fusion_gru out is a result of the operator.
struct FusionGru : public PatternBase {
FusionGru(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "fusion_gru") {}
PDNode* operator()();
PATTERN_DECL_NODE(op);
PATTERN_DECL_NODE(x);
PATTERN_DECL_NODE(weight_h);
PATTERN_DECL_NODE(weight_x);
PATTERN_DECL_NODE(out);
};
} // namespace patterns
// Link two ir::Nodes from each other.
......
......@@ -63,8 +63,9 @@ enum { U8_MAX = 255, S8_MAX = 127 };
void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
std::string input_name, double scale_to_one,
bool is_unsigned,
std::string scale_attr_name) const {
bool is_input_unsigned,
std::string scale_attr_name, float shift,
std::string shift_attr_name) const {
auto inputs = op->Op()->InputNames();
bool name_found =
std::find(inputs.begin(), inputs.end(), input_name) != inputs.end();
......@@ -72,7 +73,7 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
platform::errors::InvalidArgument(
"Var(%s) isn't the input of the %s operator.",
input_name, op->Op()->Type()));
unsigned max = is_unsigned ? U8_MAX : S8_MAX;
unsigned max = is_input_unsigned ? U8_MAX : S8_MAX;
float scale = scale_to_one * max;
// Create quantize output variable
......@@ -86,7 +87,8 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node->Name()}));
q_desc.SetAttr("Scale", scale);
q_desc.SetAttr("is_negative_input", !is_unsigned);
q_desc.SetAttr("Shift", shift);
q_desc.SetAttr("is_negative_input", !is_input_unsigned);
q_desc.SetAttr("output_format",
Has("data_layout") ? Get<std::string>("data_layout") : "NHWC");
......@@ -103,11 +105,13 @@ void CPUQuantizePass::QuantizeInput(Graph* g, Node* op, Node* input,
IR_NODE_LINK_TO(quantize_out_node, op);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}
void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
bool are_unsigned,
std::string scale_attr_name) const {
bool are_inputs_unsigned,
std::string scale_attr_name, float shift,
std::string shift_attr_name) const {
auto inputs = op->inputs;
auto output = op->outputs[0];
PADDLE_ENFORCE_GE(inputs.size(), 1,
......@@ -127,7 +131,7 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
std::vector<std::string> quantize_out_node_names(inputs.size());
double scale_out = GetScaleValueForNode(output);
unsigned max = are_unsigned ? U8_MAX : S8_MAX;
unsigned max = are_inputs_unsigned ? U8_MAX : S8_MAX;
float scale = scale_out * max;
for (size_t i = 0; i < inputs.size(); i++) {
......@@ -137,10 +141,11 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
quantize_out_node_names[i] = quantize_out_nodes[i]->Name();
q_desc.SetAttr("Scale", scale);
q_desc.SetAttr("Shift", shift);
q_desc.SetInput("Input", std::vector<std::string>({inputs[i]->Name()}));
q_desc.SetOutput("Output",
std::vector<std::string>({quantize_out_node_names[i]}));
q_desc.SetAttr("is_negative_input", !are_unsigned);
q_desc.SetAttr("is_negative_input", !are_inputs_unsigned);
auto quantize_op = g->CreateOpNode(&q_desc); // OpDesc will be copied.
// link quantize op
......@@ -154,6 +159,7 @@ void CPUQuantizePass::QuantizeInputs(Graph* g, Node* op, std::string input_name,
op->Op()->SetInput(input_name, quantize_out_node_names);
if (!scale_attr_name.empty()) op->Op()->SetAttr(scale_attr_name, scale);
if (!shift_attr_name.empty()) op->Op()->SetAttr(shift_attr_name, shift);
}
void CPUQuantizePass::DequantizeOutput(Graph* g, Node* op, Node* output,
......@@ -782,6 +788,62 @@ void CPUQuantizePass::QuantizeElementwiseAdd(Graph* graph) const {
quantize_elementwise_add_count);
}
void CPUQuantizePass::QuantizeFusionGru(Graph* graph) const {
GraphPatternDetector gpd;
patterns::FusionGru pattern{gpd.mutable_pattern(), name_scope_};
pattern();
int quantize_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "Quantize fusion_gru op";
GET_IR_NODE_FROM_SUBGRAPH(op, op, pattern);
// skip if should not be quantized
if (!platform::HasOpINT8DataType(op->Op())) {
LogQuantizationDisabled(op);
return;
}
GET_IR_NODE_FROM_SUBGRAPH(x, x, pattern);
GET_IR_NODE_FROM_SUBGRAPH(weight_h, weight_h, pattern);
GET_IR_NODE_FROM_SUBGRAPH(weight_x, weight_x, pattern);
GET_IR_NODE_FROM_SUBGRAPH(out, out, pattern);
if (!AreScalesPresentForNodes(op, {x, weight_h, weight_x})) {
LogCannotQuantizeOp(op);
return;
}
bool is_x_unsigned{false};
auto input_x_scale = GetScaleValueForNode(x, &is_x_unsigned);
double input_x_shift{128.};
if (is_x_unsigned) input_x_shift = 0.;
QuantizeInput(g, op, x, "X", input_x_scale, is_x_unsigned, "Scale_data",
input_x_shift, "Shift_data");
auto weight_scale_tensor = GetScaleTensorForNode(weight_x);
EigenVectorArrayMap eigen_tensor{weight_scale_tensor.data<double>(),
weight_scale_tensor.numel(), 1};
eigen_tensor *= static_cast<double>(S8_MAX);
std::vector<float> scale_weights{
weight_scale_tensor.data<double>(),
weight_scale_tensor.data<double>() + weight_scale_tensor.numel()};
op->Op()->SetAttr("Scale_weights", scale_weights);
// return fp32 data
op->Op()->SetAttr("force_fp32_output", true);
++quantize_count;
};
gpd(graph, handler);
AddStatis(quantize_count);
PrettyLogDetail("--- quantized %d fusion_gru ops", quantize_count);
}
void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
VLOG(3) << "Quantizing the graph.";
PADDLE_ENFORCE_NOT_NULL(
......@@ -801,6 +863,7 @@ void CPUQuantizePass::ApplyImpl(ir::Graph* graph) const {
QuantizeReshape(graph);
QuantizeMatmul(graph);
QuantizeElementwiseAdd(graph);
QuantizeFusionGru(graph);
}
} // namespace ir
......
......@@ -49,31 +49,26 @@ class CPUQuantizePass : public FusePassBase {
void ApplyImpl(ir::Graph* graph) const override;
void QuantizeConv(Graph* graph, bool with_residual_data = false) const;
void QuantizeFc(Graph* graph) const;
void QuantizePool(Graph* graph) const;
void QuantizeConcat(Graph* graph) const;
void QuantizePriorBox(Graph* graph) const;
void QuantizeTranspose(Graph* graph) const;
void QuantizeReshape(Graph* graph) const;
void QuantizeMatmul(Graph* graph) const;
void QuantizeElementwiseAdd(Graph* graph) const;
void QuantizeFusionGru(Graph* graph) const;
void QuantizeInput(Graph* g, Node* op, Node* input, std::string input_name,
double scale_to_one, bool is_unsigned,
std::string scale_attr_name = "") const;
double scale_to_one, bool is_input_unsigned,
std::string scale_attr_name = "", float shift = 0.0,
std::string shift_attr_name = "") const;
// quantize all inputs of given name with the same (minimum) scale
void QuantizeInputs(Graph* g, Node* op, std::string input_name,
bool are_unsigned,
std::string scale_attr_name = "") const;
bool are_inputs_unsigned,
std::string scale_attr_name = "", float shift = 0.0,
std::string shift_attr_name = "") const;
void DequantizeOutput(Graph* g, Node* op, Node* output,
std::string output_name, double scale_to_one,
......
......@@ -91,6 +91,16 @@ void SetOp(ProgramDesc* prog, const std::string& type, const std::string& name,
op->SetAttr("Scale_x", 1.0f);
op->SetAttr("Scale_y", 1.0f);
op->SetAttr("Scale_out", 1.0f);
} else if (type == "fusion_gru") {
op->SetInput("X", {inputs[0]});
op->SetInput("Bias", {inputs[1]});
op->SetInput("WeightX", {inputs[2]});
op->SetInput("WeightH", {inputs[3]});
op->SetOutput("Hidden", {outputs[0]});
op->SetAttr("mkldnn_data_type", mkldnn_data_type);
op->SetAttr("Scale_data", 1.0f);
op->SetAttr("Shift_data", 0.0f);
op->SetAttr("Weight_scale", std::vector<float>{1.0f});
}
}
......@@ -389,6 +399,77 @@ TEST(CpuQuantizePass, transpose) {
quant_count, dequant_count, added_nodes_count, 2.0f * 127);
}
static const std::initializer_list<std::string> variable_names_fusion_gru = {
"x", "wx", "wh", "b", "h"};
// x->Fusion_gru->h
ProgramDesc BuildProgramDescFusionGru() {
ProgramDesc prog;
for (auto& v : variable_names_transpose) {
auto* var = prog.MutableBlock(0)->Var(v);
if (v.find("wx") == 0 || v.find("wh") || v.find("b")) {
var->SetPersistable(true);
}
}
SetOp(&prog, "fusion_gru", "Fusion_gru", {"x", "wx", "wh", "b"}, {"h"}, true,
"int8");
return prog;
}
void MainTestFusionGru(const ProgramDesc& prog, int gru_count, int quant_count,
int dequant_count, int added_nodes_count, float scale,
float shift) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num, current_nodes_num;
PreparePass(&graph, prog, variable_names_fusion_gru, &original_nodes_num,
&current_nodes_num);
int quantize_nodes_count = 0;
int dequantize_nodes_count = 0;
int gru_nodes_count = 0;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op = node->Op();
if (op->Type() == "fusion_gru") {
gru_nodes_count++;
auto op_name = BOOST_GET_CONST(std::string, op->GetAttr("name"));
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Scale_data")), scale)
<< "Scale_data for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(float, op->GetAttr("Shift_data")), shift)
<< "Shift_data for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(std::vector<float>,
op->GetAttr("Scale_weights"))[0],
scale)
<< "Scale_weights for node '" + op_name + "'.";
EXPECT_EQ(BOOST_GET_CONST(bool, op->GetAttr("force_fp32_output")), true)
<< "force_fp32_output for node '" + op_name + "'.";
} else if (op->Type() == "quantize") {
quantize_nodes_count++;
} else if (op->Type() == "dequantize") {
dequantize_nodes_count++;
}
}
}
EXPECT_EQ(gru_nodes_count, gru_count);
EXPECT_EQ(quantize_nodes_count, quant_count);
EXPECT_EQ(dequantize_nodes_count, dequant_count);
EXPECT_EQ(original_nodes_num + added_nodes_count, current_nodes_num);
}
TEST(CpuQuantizePass, fusion_gru) {
// x->Fusion_gru->h
int gru_count = 1;
int quant_count = 1;
int dequant_count = 0;
// 1 Quant + 1 IN + 0 DeQuant + 0 OUT
int added_nodes_count = 1 + 1 + 0 + 0;
MainTestFusionGru(BuildProgramDescFusionGru(), gru_count, quant_count,
dequant_count, added_nodes_count, 2. * 127, 128.);
}
static const std::initializer_list<std::string> variable_names_reshape = {
"a", "w1", "b", "c", "d", "e", "f"};
......
......@@ -76,6 +76,8 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
BOOST_GET_CONST(float, dequant_op->Op()->GetAttr("Scale"));
float quant_scale =
BOOST_GET_CONST(float, quant_op->Op()->GetAttr("Scale"));
float dequant_shift = dequant_op->Op()->GetAttrIfExists<float>("Shift");
float quant_shift = quant_op->Op()->GetAttrIfExists<float>("Shift");
PADDLE_ENFORCE_NE(
nodes_keep_counter->find(dequant_out), nodes_keep_counter->end(),
platform::errors::NotFound("The dequant output node is not found."));
......@@ -83,7 +85,7 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
// check if dequantize op should be kept or removed, decrease the counter
bool keep_dequant = (*nodes_keep_counter)[dequant_out]-- > 1;
if (dequant_scale == quant_scale) {
if (dequant_scale == quant_scale && dequant_shift == quant_shift) {
// squash dequantize-quantize to nothing
auto quant_out_var_name = quant_out->Name();
auto next_op_inputs = next_op_desc->InputNames();
......@@ -110,7 +112,9 @@ void CPUQuantizeSquashPass::DequantQuantSquash(
desc.SetInput("Input", std::vector<std::string>({dequant_in->Name()}));
desc.SetOutput("Output", std::vector<std::string>({quant_out->Name()}));
desc.SetAttr("Scale_in", dequant_scale);
desc.SetAttr("Shift_in", dequant_shift);
desc.SetAttr("Scale_out", quant_scale);
desc.SetAttr("Shift_out", quant_shift);
auto requant_op = g->CreateOpNode(&desc);
......@@ -293,6 +297,7 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
}));
auto* first_quant_out = first_quant_op->outputs[0];
float scale = first_quant_op->Op()->GetAttrIfExists<float>("Scale");
float shift = first_quant_op->Op()->GetAttrIfExists<float>("Shift");
PADDLE_ENFORCE_NE(scale, 0,
platform::errors::InvalidArgument(
......@@ -302,7 +307,8 @@ void CPUQuantizeSquashPass::MultipleQuantizeSquash(Graph* graph) const {
auto quant_op = prev_out->outputs[iter];
if (quant_op->IsOp() && quant_op->Op()->Type() == "quantize" &&
quant_op->id() != first_quant_op->id() &&
quant_op->Op()->GetAttrIfExists<float>("Scale") == scale) {
quant_op->Op()->GetAttrIfExists<float>("Scale") == scale &&
quant_op->Op()->GetAttrIfExists<float>("Shift") == shift) {
auto quant_out = quant_op->outputs[0];
auto last_op = quant_out->outputs[0];
......
......@@ -95,7 +95,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
// Create memory descriptors
auto input_md = MKLDNNMemDesc({Ti, N, IC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
MKLDNNMemoryFormat::ntc);
auto weight_x_md =
MKLDNNMemDesc({L, D, IC, G, OC}, weights_dt, MKLDNNMemoryFormat::any);
auto weight_h_md =
......@@ -103,7 +103,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
auto bias_md = MKLDNNMemDesc({L, D, G, OC}, MKLDNNGetDataType<float>(),
MKLDNNMemoryFormat::ldgo);
auto hidden_md = MKLDNNMemDesc({Ti, N, OC}, MKLDNNGetDataType<T_out>(),
MKLDNNMemoryFormat::any);
MKLDNNMemoryFormat::ntc);
auto h0_md = MKLDNNMemDesc({L, D, N, OC}, MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::ldnc);
......
......@@ -66,6 +66,7 @@ class Quant2Int8MkldnnPass(object):
self._fc_ops = ['fc']
self._relu_ops = ['relu', 'relu6']
self._matmul_ops = ['matmul']
self._gru_ops = ['fusion_gru']
self._weight_scales = {}
# Collect the Input and Output sclaes from Fake quant models
self._var_quant_scales = {}
......@@ -449,8 +450,43 @@ class Quant2Int8MkldnnPass(object):
self._var_quant_scales[weight_var_name] = (use_unsigned_int,
lod_tensor)
def _compute_gru_weight_scales(wx_name, wh_name):
for op in graph.all_op_nodes():
if op.op().type() in self._gru_ops:
wx_var_name = op.input(wx_name)[0]
wh_var_name = op.input(wh_name)[0]
wx = np.array(self._load_param(self._scope, wx_var_name))
wh = np.array(self._load_param(self._scope, wh_var_name))
OC = wh.shape[0]
scale_ur = 1.0 / np.max(np.abs(
np.concatenate(
[
wx[:, :2 * OC], wh.flatten()[:2 * OC * OC]
.reshape(OC, 2 * OC)
],
axis=0)),
axis=0)
scale_o = 1.0 / np.max(np.abs(
np.concatenate(
[
wx[:, 2 * OC:], wh.flatten()[2 * OC * OC:]
.reshape(OC, OC)
],
axis=0)),
axis=0)
gru_weights_scale = np.concatenate(
[scale_ur, scale_o]).astype('float')
lod_tensor = self._convert_scale2tensor(gru_weights_scale)
use_unsigned_int = False
self._var_quant_scales[wx_var_name] = (use_unsigned_int,
lod_tensor)
_compute_var_scales(self._conv_ops, "Filter", axis=1)
_compute_var_scales(self._fc_ops, "W", axis=0)
_compute_var_scales(self._gru_ops, "WeightH", axis=0)
_compute_gru_weight_scales("WeightX", "WeightH")
return graph
def _find_avg_pooling_ids(self, graph):
......
......@@ -98,18 +98,16 @@ function(download_quant_model install_dir data_file)
endif()
endfunction()
function(save_quant_ic_model_test target quant_model_dir fp32_model_save_path int8_model_save_path)
function(save_quant_ic_model_test target quant_model_dir int8_model_save_path)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS --quant_model_path ${quant_model_dir}
--fp32_model_save_path ${fp32_model_save_path}
--int8_model_save_path ${int8_model_save_path}
--debug)
endfunction()
function(save_quant_nlp_model_test target quant_model_dir fp32_model_save_path int8_model_save_path ops_to_quantize)
function(save_quant_nlp_model_test target quant_model_dir int8_model_save_path ops_to_quantize)
py_test(${target} SRCS ${CMAKE_CURRENT_SOURCE_DIR}/save_quant_model.py
ARGS --quant_model_path ${quant_model_dir}
--fp32_model_save_path ${fp32_model_save_path}
--int8_model_save_path ${int8_model_save_path}
--ops_to_quantize ${ops_to_quantize})
endfunction()
......@@ -227,8 +225,6 @@ if(LINUX AND WITH_MKLDNN)
set(NLP_LABLES_PATH "${NLP_DATA_DIR}/Ernie_dataset/label.xnli.dev")
download_quant_data(${NLP_DATA_DIR} ${NLP_DATA_ARCHIVE})
set(QUANT2_NLP_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add")
# Quant2 Ernie
set(QUANT2_ERNIE_MODEL_ARCHIVE "ernie_qat.tar.gz")
set(QUANT2_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_quant2")
......@@ -236,17 +232,25 @@ if(LINUX AND WITH_MKLDNN)
set(FP32_ERNIE_MODEL_ARCHIVE "ernie_fp32_model.tar.gz")
set(FP32_ERNIE_MODEL_DIR "${QUANT_INSTALL_DIR}/Ernie_float")
download_quant_fp32_model(${FP32_ERNIE_MODEL_DIR} ${FP32_ERNIE_MODEL_ARCHIVE})
inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_NLP_OPS_TO_QUANTIZE})
set(QUANT2_ERNIE_OPS_TO_QUANTIZE "fc,reshape2,transpose2,matmul,elementwise_add")
inference_quant2_int8_nlp_test(test_quant2_int8_ernie_mkldnn ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${FP32_ERNIE_MODEL_DIR}/ernie_fp32_model ${NLP_DATA_PATH} ${NLP_LABLES_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
# Quant2 GRU
set(QUANT2_GRU_MODEL_ARCHIVE "GRU_quant_acc.tar.gz")
set(QUANT2_GRU_MODEL_DIR "${QUANT_INSTALL_DIR}/GRU_quant2")
download_quant_model(${QUANT2_GRU_MODEL_DIR} ${QUANT2_GRU_MODEL_ARCHIVE})
set(QUANT2_GRU_OPS_TO_QUANTIZE "fusion_gru")
### Save FP32 model or INT8 model from Quant model
set(QUANT2_INT8_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_int8")
set(QUANT2_FP32_RESNET50_SAVE_PATH "${QUANT_INSTALL_DIR}/ResNet50_quant2_fp32")
save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_FP32_RESNET50_SAVE_PATH} ${QUANT2_INT8_RESNET50_SAVE_PATH})
save_quant_ic_model_test(save_quant2_model_resnet50 ${QUANT2_RESNET50_MODEL_DIR}/ResNet50_qat_perf/float ${QUANT2_INT8_RESNET50_SAVE_PATH})
set(QUANT2_INT8_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8")
set(QUANT2_FP32_ERNIE_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_fp32")
save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_FP32_ERNIE_SAVE_PATH} ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_NLP_OPS_TO_QUANTIZE})
save_quant_nlp_model_test(save_quant2_model_ernie ${QUANT2_ERNIE_MODEL_DIR}/Ernie_qat/float ${QUANT2_INT8_ERNIE_SAVE_PATH} ${QUANT2_ERNIE_OPS_TO_QUANTIZE})
set(QUANT2_INT8_GRU_SAVE_PATH "${QUANT_INSTALL_DIR}/GRU_quant2_int8")
save_quant_nlp_model_test(save_quant2_model_gru ${QUANT2_GRU_MODEL_DIR}/GRU_quant_acc ${QUANT2_INT8_GRU_SAVE_PATH} ${QUANT2_GRU_OPS_TO_QUANTIZE})
# Convert Quant2 model to dot and pdf files
set(QUANT2_INT8_ERNIE_DOT_SAVE_PATH "${QUANT_INSTALL_DIR}/Ernie_quant2_int8_dot_file")
......
......@@ -45,9 +45,10 @@ class TestFusionGRUINT8MKLDNNOp(OpTest):
# Input data
x_f32 = np.random.rand(T, self.IC).astype('float32') * 2 - 1
scale_data = 63
shift_data = 64
x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8)
scale_data = 63.0
shift_data = 64.0
x_u8 = np.rint(x_f32 * scale_data + shift_data).astype(np.uint8)
# x_u8 = (x_f32 * scale_data + shift_data).astype(np.uint8)
# WeightX/WeightH data
wx = np.random.rand(self.IC, 3 * self.OC).astype('float32') * 2 - 1
......@@ -58,22 +59,23 @@ class TestFusionGRUINT8MKLDNNOp(OpTest):
# WeightX data shape in PP: [IC, 3 * OC]
# WeightH data shape in PP: [OC, 2 * OC] + [OC, OC]
# Scales shape in oneDNN: [3, OC]
scale_ur = 63 / np.max(np.abs(
s8_max = 127.0
scale_ur = s8_max / np.max(np.abs(
np.concatenate(
[
wx[:, :2 * self.OC], wh.flatten()[:2 * self.OC * self.OC]
.reshape(self.OC, 2 * self.OC)
],
axis=0)),
axis=0)
scale_o = 63 / np.max(np.abs(
axis=0)
scale_o = s8_max / np.max(np.abs(
np.concatenate(
[
wx[:, 2 * self.OC:], wh.flatten()[2 * self.OC * self.OC:]
.reshape(self.OC, self.OC)
],
axis=0)),
axis=0)
axis=0)
scale_weights = np.concatenate([scale_ur, scale_o]).astype('float')
......@@ -102,7 +104,9 @@ class TestFusionGRUINT8MKLDNNOp(OpTest):
self.outputs = {'Hidden': (hidden_f32, self.lod)}
else:
self.error_margin = 1
hidden_u8 = (hidden_f32 * scale_data + shift_data).astype(np.uint8)
hidden_u8 = np.rint(hidden_f32 * scale_data + shift_data).astype(
np.uint8)
# hidden_u8 = (hidden_f32 * scale_data + shift_data).astype(np.uint8)
self.outputs = {'Hidden': (hidden_u8, self.lod)}
self.attrs = {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册