Commit 51e9898d authored by cc, committed by GitHub

Modify quant_dequant_fuse_pass to process quant_dequant_op, test=develop (#3341)

Parent add162dc
......@@ -44,7 +44,10 @@ DEFINE_string(input_shape,
"set input shapes according to the model, "
"separated by colon and comma, "
"such as 1,3,244,244");
DEFINE_string(input_img_path, "", "the path of input image");
DEFINE_string(input_img_path,
"",
"the path of input image, if not set "
"input_img_path, the input of model will be 1.0.");
DEFINE_int32(warmup, 0, "warmup times");
DEFINE_int32(repeats, 1, "repeats times");
DEFINE_int32(power_mode,
......@@ -57,16 +60,11 @@ DEFINE_int32(power_mode,
DEFINE_int32(threads, 1, "threads num");
DEFINE_string(result_filename,
"result.txt",
"save benchmark "
"result to the file");
"save the inference time to the file.");
DEFINE_bool(run_model_optimize,
false,
"if set true, apply model_optimize_tool to "
"model and use optimized model to test. ");
DEFINE_bool(is_quantized_model,
false,
"if set true, "
"test the performance of the quantized model. ");
namespace paddle {
namespace lite_api {
......@@ -87,10 +85,6 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
std::vector<Place> vaild_places = {
Place{TARGET(kARM), PRECISION(kFloat)},
};
if (FLAGS_is_quantized_model) {
vaild_places.insert(vaild_places.begin(),
Place{TARGET(kARM), PRECISION(kInt8)});
}
config.set_valid_places(vaild_places);
auto predictor = lite_api::CreatePaddlePredictor(config);
......@@ -181,8 +175,8 @@ void Run(const std::vector<int64_t>& input_shape,
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir == "" || FLAGS_result_filename == "") {
LOG(INFO) << "please run ./benchmark_bin --help to obtain usage.";
if (FLAGS_model_dir == "") {
LOG(INFO) << "Please run ./benchmark_bin --help to obtain usage.";
exit(0);
}
......
......@@ -295,6 +295,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
inner_places.emplace_back(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
// Analyze whether the model is quantized.
// For quantized model, add place(arm, int8) to inner_places
const std::vector<std::string> quant_dequant_op = {
"fake_quantize_abs_max",
"fake_quantize_range_abs_max",
......@@ -317,7 +319,8 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
}
}
if (is_quantized_model) {
inner_places.emplace_back(Place{TARGET(kARM), PRECISION(kInt8)});
inner_places.insert(inner_places.begin(),
Place{TARGET(kARM), PRECISION(kInt8)});
}
Program program(desc, scope_, inner_places);
......
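For context, the detection added above amounts to scanning the program's ops for any fake-quant op type and, if one is found, prepending the int8 place. A minimal self-contained sketch under that assumption (the full op-type list is elided in this diff, and the function name here is illustrative, not Paddle-Lite API):

#include <algorithm>
#include <string>
#include <vector>

// Returns true if any op type in the program is a fake-quant op.
// Only the two op types visible in this diff are listed; the real
// list in the commit is longer.
bool IsQuantizedModel(const std::vector<std::string>& op_types) {
  const std::vector<std::string> quant_dequant_op = {
      "fake_quantize_abs_max", "fake_quantize_range_abs_max"};
  for (const auto& type : op_types) {
    if (std::find(quant_dequant_op.begin(), quant_dequant_op.end(), type) !=
        quant_dequant_op.end()) {
      return true;
    }
  }
  return false;
}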
......@@ -44,11 +44,9 @@ void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fuser(graph.get());
}
// delete quant_dequant_node
for (auto op_type : {"pool2d", "softmax", "elementwise_add"}) {
fusion::DeleteQuantDequantOpFuser fuser(op_type);
fuser(graph.get());
}
// process quant_dequant_node
fusion::DeleteQuantDequantOpFuser dqd_fuser;
dqd_fuser(graph.get());
}
} // namespace mir
......
......@@ -50,7 +50,7 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
// obtain values, save values and relink node
// obtain scale, save attrs and relink node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
......@@ -58,11 +58,22 @@ void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
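// Worked example (illustrative numbers only): bit_length = 8 gives
// range = (1 << 7) - 1 = 127; if the saved scale tensor holds 2.54,
// scale_value = 2.54 / 127 = 0.02.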
auto in_act_name = input_act_node->arg()->name;
auto out_act_name = output_act_node->arg()->name;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node : outlinks) {
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("input_scale", scale_value);
// save the input scale in the quantized op, keyed by input argname + index
auto op_desc = *quantized_node->stmt()->mutable_op_info();
std::string argname;
int index;
op_desc.GetInputArgname(out_act_name, &argname);
op_desc.GetInputIndex(out_act_name, &index);
op_desc.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
scale_value);
op_desc.SetAttr<float>("input_scale", scale_value); // save it for now
op_desc.SetAttr<int>("bit_length", bit_length);
op_desc.UpdateAllInputs(out_act_name, in_act_name);
quantized_node->stmt()->ResetOp(op_desc, graph->valid_places());
IR_NODE_LINK_TO(input_act_node, quantized_node)
}
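To make the attribute naming above concrete (values assumed for illustration only): if a consumer op reads the quantized activation through input argument "X" at index 0, the saved key is "X0_input_scale".

// Hypothetical walk-through, assuming the consumer's inputs are
// {"X": {"act_out"}} and out_act_name == "act_out":
//   GetInputArgname("act_out", &argname)  -> argname == "X"
//   GetInputIndex("act_out", &index)      -> index == 0
//   attribute key: "X" + "0" + "_input_scale" == "X0_input_scale"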
......@@ -125,19 +136,18 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
auto* dequant_op = matched.at("dequant_op");
auto* dequant_op_out = matched.at("dequant_op_out");
// obtain input_scale and weight_scale
// obtain weight_scale from max_range
auto* scope = quantized_op->stmt()->op()->scope();
auto& valid_places = quantized_op->stmt()->op()->valid_places();
int bit_length = quantized_op->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
float input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
float max_range = dequant_op->stmt()->op_info()->GetAttr<float>("max_range");
float whole_weight_scale =
static_cast<float>(range * range) / max_range / range;
// max_range = range * range / max(abs(weight))
// weight_scale = range * range / (range * range / max(abs(weight))) / range
// = max(abs(weight)) / range
// As: max_range = range * range / max(abs(weight))
// So: whole_weight_scale
// = range * range / (range * range / max(abs(weight))) / range
// = max(abs(weight)) / range
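// Worked example (illustrative numbers only): bit_length = 8 -> range = 127.
// If max(abs(weight)) = 0.5, then max_range = 127 * 127 / 0.5 = 32258 and
// whole_weight_scale = 127 * 127 / 32258 / 127 = 0.5 / 127 ≈ 0.003937.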
// set op desc
cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
......@@ -153,7 +163,7 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
// Conv weight shape: Cout * Cin * kh * kw, the weight_scale_size should
// be Cout.
weight_scale_size = quantized_weight_t->dims()[0];
} else if (quantized_op_type_ == "mul") {
} else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
// Fc weight: Cin * Cout, the weight_scale_size should be Cout.
......@@ -163,7 +173,6 @@ void DequantOpFuser::InsertNewNode(SSAGraph* graph,
weight_scale.push_back(whole_weight_scale);
}
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
// change the weight from the float type to int8 type.
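The conversion itself is elided in this hunk; given the scales computed above, it amounts to dividing each float weight by its scale and rounding into the int8 range. A hedged sketch (free function and flat layout assumed; the real pass operates on lite::Tensor and, for conv, on per-channel scales):

#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative only: quantize float weights with a single scale so that
// the dequantized values (q[i] * scale) approximate the original weights.
std::vector<int8_t> QuantizeWeights(const std::vector<float>& w, float scale) {
  std::vector<int8_t> q(w.size());
  for (size_t i = 0; i < w.size(); ++i) {
    q[i] = static_cast<int8_t>(std::round(w[i] / scale));
  }
  return q;
}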
......@@ -209,6 +218,7 @@ void ChannelWiseDequantOpFuser::BuildPattern() {
->assert_is_op_output(quantized_op_type_)
->assert_is_op_input(dequant_op_type, "X")
->AsIntermediate();
// The scale var_node of input activation is deleted in DeleteQuantOpFuser
auto* dequant_op_channel_scale = VarNode("dequant_op_channel_scale")
->assert_is_op_input(dequant_op_type)
->AsIntermediate();
......@@ -237,11 +247,9 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
auto* dequant_op = matched.at("dequant_op");
auto* dequant_op_out = matched.at("dequant_op_out");
// obtain input_scale and weight_scale
// obtain weight_scale from the fake_dequant op
auto* scope = quantized_op->stmt()->op()->scope();
auto& valid_places = quantized_op->stmt()->op()->valid_places();
float input_scale =
quantized_op->stmt()->op_info()->GetAttr<float>("input_scale");
std::vector<float> weight_scale;
std::vector<int> quant_bits =
......@@ -258,11 +266,15 @@ void ChannelWiseDequantOpFuser::InsertNewNode(SSAGraph* graph,
// set op desc
cpp::OpDesc op_desc = *quantized_op->stmt()->op_info();
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
if (quantized_op_type_ == "conv2d" ||
quantized_op_type_ == "depthwise_conv2d") {
op_desc.SetInput("Input", {quantized_op_input->arg()->name});
op_desc.SetOutput("Output", {dequant_op_out->arg()->name});
} else if (quantized_op_type_ == "mul" || quantized_op_type_ == "matmul") {
op_desc.SetInput("X", {quantized_op_input->arg()->name});
op_desc.SetOutput("Out", {dequant_op_out->arg()->name});
}
op_desc.SetAttr("enable_int8", true);
op_desc.SetAttr("input_scale", input_scale);
op_desc.SetAttr("weight_scale", weight_scale);
// change the weight from the float type to int8 type.
......@@ -297,167 +309,65 @@ cpp::OpDesc ChannelWiseDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
void DeleteQuantDequantOpFuser::BuildPattern() {
std::string quant_dequant_op_type =
"fake_quantize_dequantize_moving_average_abs_max";
if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_node = VarNode("input_act_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_node =
OpNode("quant_dequant_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_node =
VarNode("output_act_node")
->assert_is_op_output(quant_dequant_op_type, "Out");
auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
->assert_is_op(quantized_op_type_);
quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_dequant_node});
output_act_node->LinksFrom({quant_dequant_node});
quantized_node->LinksFrom({output_act_node});
} else if (quantized_op_type_ == "elementwise_add") {
auto* input_scale_left_node =
VarNode("input_scale_left_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_left_node =
VarNode("input_act_left_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_left_node =
OpNode("quant_dequant_left_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_left_node =
VarNode("output_scale_left_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_left_node =
VarNode("output_act_left_node")
->assert_is_op_output(quant_dequant_op_type, "Out")
->assert_is_op_input(quantized_op_type_, "X");
quant_dequant_left_node->LinksFrom(
{input_scale_left_node, input_act_left_node});
output_scale_left_node->LinksFrom({quant_dequant_left_node});
output_act_left_node->LinksFrom({quant_dequant_left_node});
auto* input_scale_right_node =
VarNode("input_scale_right_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_right_node =
VarNode("input_act_right_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_right_node =
OpNode("quant_dequant_right_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_right_node =
VarNode("output_scale_right_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_right_node =
VarNode("output_act_right_node")
->assert_is_op_output(quant_dequant_op_type, "Out")
->assert_is_op_input(quantized_op_type_, "Y");
quant_dequant_right_node->LinksFrom(
{input_scale_right_node, input_act_right_node});
output_scale_right_node->LinksFrom({quant_dequant_right_node});
output_act_right_node->LinksFrom({quant_dequant_right_node});
auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
->assert_is_op(quantized_op_type_);
quantized_node->LinksFrom({output_act_left_node, output_act_right_node});
} else {
LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
}
VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:"
<< quantized_op_type_;
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_node =
VarNode("input_act_node")->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_node = OpNode("quant_dequant_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_node =
VarNode("output_act_node")
->assert_is_op_output(quant_dequant_op_type, "Out");
quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_dequant_node});
output_act_node->LinksFrom({quant_dequant_node});
}
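For reference, the single pattern built above has this shape; consumers of Out are no longer matched here but are visited later in InsertNewNode:

  input_scale_node (InScale) ─┐                     ┌─> output_scale_node (OutScale)
                              ├─> quant_dequant_node┤
  input_act_node (X) ─────────┘                     └─> output_act_node (Out)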
void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
if (quantized_op_type_ == "pool2d" || quantized_op_type_ == "softmax") {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_dequant_node = matched.at("quant_dequant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
auto* quantized_node = matched.at("quantized_node");
// obtain values, save values and relink node
int bit_length =
quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("input_scale", scale_value);
op_desc->SetInput("X", {input_act_node->arg()->name});
IR_NODE_LINK_TO(input_act_node, quantized_node)
auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_node,
quant_dequant_node,
output_scale_node,
output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
} else if (quantized_op_type_ == "elementwise_add") {
auto* input_scale_left_node = matched.at("input_scale_left_node");
auto* input_act_left_node = matched.at("input_act_left_node");
auto* quant_dequant_left_node = matched.at("quant_dequant_left_node");
auto* output_scale_left_node = matched.at("output_scale_left_node");
auto* output_act_left_node = matched.at("output_act_left_node");
auto* input_scale_right_node = matched.at("input_scale_right_node");
auto* input_act_right_node = matched.at("input_act_right_node");
auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
auto* output_scale_right_node = matched.at("output_scale_right_node");
auto* output_act_right_node = matched.at("output_act_right_node");
auto* quantized_node = matched.at("quantized_node");
// obtain values, save values and relink node
int bit_length =
quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_left_node->stmt()->op()->scope();
auto* left_scale_tensor =
scope->FindVar(output_scale_left_node->arg()->name)
->GetMutable<lite::Tensor>();
float left_scale_value = left_scale_tensor->data<float>()[0] / range;
auto* right_scale_tensor =
scope->FindVar(output_scale_right_node->arg()->name)
->GetMutable<lite::Tensor>();
float right_scale_value = right_scale_tensor->data<float>()[0] / range;
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("x_input_scale", left_scale_value);
op_desc->SetAttr<float>("y_input_scale", right_scale_value);
op_desc->SetInput("X", {input_act_left_node->arg()->name});
op_desc->SetInput("Y", {input_act_right_node->arg()->name});
IR_NODE_LINK_TO(input_act_left_node, quantized_node)
IR_NODE_LINK_TO(input_act_right_node, quantized_node)
auto update_op_desc = *quantized_node->stmt()->mutable_op_info();
quantized_node->stmt()->ResetOp(update_op_desc, graph->valid_places());
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
quant_dequant_left_node,
output_scale_left_node,
output_act_left_node,
input_scale_right_node,
quant_dequant_right_node,
output_scale_right_node,
output_act_right_node};
GraphSafeRemoveNodes(graph, nodes2rm);
} else {
LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_dequant_node = matched.at("quant_dequant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
auto input_act_name = input_act_node->arg()->name;
auto output_act_name = output_act_node->arg()->name;
// Get scale value from scale var node
int bit_length =
quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto quantized_nodes = output_act_node->outlinks;
for (auto* quantized_node : quantized_nodes) {
// Save quantization info in op_info attr
auto op_info = *quantized_node->stmt()->op_info();
std::string argname;
int index;
op_info.GetInputArgname(output_act_name, &argname);
op_info.GetInputIndex(output_act_name, &index);
op_info.SetAttr<float>(argname + std::to_string(index) + "_input_scale",
scale_value);
op_info.SetAttr<float>("input_scale", scale_value); // Save it for now
op_info.SetAttr<int>("bit_length", bit_length);
op_info.UpdateAllInputs(output_act_name, input_act_name);
quantized_node->stmt()->ResetOp(op_info, graph->valid_places());
IR_NODE_LINK_TO(input_act_node, quantized_node);
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_dequant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
}
cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
......
......@@ -87,24 +87,16 @@ class ChannelWiseDequantOpFuser : public FuseBase {
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
* pool2d/elementwise_add" can be detected by this fuser. The fuser
* extracts the input_scale from the fake_quant_dequant_op and saves it into
* the quantized_op. Besides, the fuser deletes the fake_quant_dequant_op in
* the graph.
* quantized_op" can be detected by this fuser. The fuser modifies the input
* scale for the quantized_op and deletes the fake_quant_dequant_op.
*/
class DeleteQuantDequantOpFuser : public FuseBase {
public:
explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
: quantized_op_type_(quantized_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quantized_op_type_{};
};
} // namespace fusion
......
......@@ -225,6 +225,32 @@ class OpInfo : public cpp::OpDesc {
return false;
}
// For the input variable name, find its index within the variable list
// bound to the corresponding input argname
bool GetInputIndex(const std::string &value_name, int *out) const {
for (auto &item : inputs_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = it - item.second.begin();
return true;
}
}
return false;
}
// For the output variable name, find its index within the variable list
// bound to the corresponding output argname
bool GetOutputIndex(const std::string &value_name, int *out) const {
for (auto &item : outputs_) {
auto it = std::find(item.second.begin(), item.second.end(), value_name);
if (it != item.second.end()) {
*out = it - item.second.begin();
return true;
}
}
return false;
}
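// Usage sketch (hypothetical values): if inputs_ holds {"X": {"a", "b"}},
// then GetInputIndex("b", &i) sets i = 1 and returns true. For a variable
// name bound to no argument, both helpers return false and leave *out
// unchanged.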
void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) {
for (auto &var : item.second) {
......