未验证 提交 5cda6b2b 编写于 作者: W Wangzheee 提交者: GitHub

fix: delete_quant_dequant_filter_op_pass, delete_quant_dequant_op_pass (#35879)

上级 1238115e
...@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
int range = ((1 << (bit_length - 1)) - 1); int range = ((1 << (bit_length - 1)) - 1);
std::vector<float> weight_scale; std::vector<float> weight_scale;
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name(); std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op(); auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs(); auto var_map = any_op2_desc->Inputs();
std::string arg_name = ""; std::string arg_name = "";
...@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument( PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
"can not find the input %s.", "can not find the input %s.",
quant_dequant_op_out_name)); quant_dequant_op_out_name));
any_op2_desc->SetAttr("enable_int8", true); // any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length); any_op2_desc->SetAttr("bit_length", bit_length);
// modify the any_op2's inputs // modify the any_op2's inputs
any_op2_desc->Flush();
auto dequant_type = quant_dequant_op->Op()->Type(); auto dequant_type = quant_dequant_op->Op()->Type();
auto quantized_op_type = any_op2_desc->Type();
// get weight tensor // get weight tensor
auto* weight_tensor = auto* weight_tensor =
scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>(); scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims(); auto w_dims = weight_tensor->dims();
float* quantized_weight_data = float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace()); weight_tensor->mutable_data<float>(platform::CPUPlace());
// Get weight scale // Get weight scale
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") { if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
auto scales_name = quant_dequant_op->Op()->Output("OutScale"); int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
// To Do @Wangzheee: use "OutScale" to quantdequant
/*auto scales_name = quant_dequant_op->Op()->Output("OutScale");
PADDLE_ENFORCE_EQ(scales_name.size(), 1, PADDLE_ENFORCE_EQ(scales_name.size(), 1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Scales size in channel-wise quant dequantize op " "Scales size in channel-wise quant dequantize op "
"should be 1, got %d.", "should be 1, got %d.",
scales_name.size())); scales_name.size()));
const LoDTensor& channel_scale_tensor = const LoDTensor& channel_scale_tensor =
scope->GetVar(scales_name[0])->Get<LoDTensor>(); scope->FindVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE( PADDLE_ENFORCE(
paddle::platform::is_cpu_place(channel_scale_tensor.place()), paddle::platform::is_cpu_place(channel_scale_tensor.place()),
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU.")); "Channel scale tensor's place should be CPU."));
// compute the channel wise abs max of the weight tensor // compute the channel wise abs max of the weight tensor
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true, const float* channel_scale_data = channel_scale_tensor.data<float>();
platform::errors::InvalidArgument( for (int i = 0; i < channel_scale_tensor.numel(); i++) {
"'quant_axis' should be 0 or 1, but " weight_scale.push_back(channel_scale_data[i] );
"the received is %d", }*/
quant_axis));
// Implement channel_wise_quantize_dequantize_abs_max quantization
// algorithm
const int64_t channel = w_dims[quant_axis]; const int64_t channel = w_dims[quant_axis];
weight_scale.resize(channel, 0); weight_scale.resize(channel, 0);
if (quant_axis == 0) { if (quant_axis == 0) {
...@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE(weight_scale[i], 0, PADDLE_ENFORCE_NE(weight_scale[i], 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero.")); "Weight scale should be nonzero, but get zero."));
weight_scale[i] = range / weight_scale[i]; weight_scale[i] = weight_scale[i] / range;
} }
} else { } else {
auto scale_name = quant_dequant_op_outscale->Name(); // Implement quantize_dequantize_abs_max quantization algorithm
// compute the abs max of the weight tensor
float abs_max_weight = 0.; float abs_max_weight = 0.;
for (int j = 0; j < weight_tensor->numel(); j++) { for (int j = 0; j < weight_tensor->numel(); j++) {
abs_max_weight = abs_max_weight =
...@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const { ...@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE(abs_max_weight, 0, PADDLE_ENFORCE_NE(abs_max_weight, 0,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero")); "Weight scale should be nonzero, but get zero"));
weight_scale.push_back((range * range) / abs_max_weight / range); weight_scale.push_back(abs_max_weight / range);
} }
nodes2rm.insert(quant_dequant_op_outscale); nodes2rm.insert(quant_dequant_op_outscale);
// perform quantize dequantize operations
// If quantized op is not channel wise, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (dequant_type == "fake_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), 1,
platform::errors::InvalidArgument(
"%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
quantized_op_type, weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[0];
}
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "fc") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul op weight dequantized by "
"[fake_channel_wise_quantize_dequantize_abs_max] requires "
"weight scale "
"size = 2nd dim of mul's weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j % w_dims[1]], 0,
platform::errors::InvalidArgument(
"fc op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[0]), weight_scale.size()));
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j / inner_size], 0,
platform::errors::InvalidArgument(
"conv2d op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j / inner_size];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j / inner_size];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d_transpose") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
int inner_size = w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
platform::errors::InvalidArgument(
"conv2d_transpose op weight scale should be "
"nonzero, but get zero"));
quantized_weight_data[j] = quantized_weight_data[j] *
weight_scale[(j / inner_size) % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /=
weight_scale[(j / inner_size) % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
nodes2rm.insert(quant_dequant_op_out); nodes2rm.insert(quant_dequant_op_out);
// link weight in quant_dequant_op_x to any_op2 // link weight in quant_dequant_op_x to any_op2
......
...@@ -28,76 +28,85 @@ namespace ir { ...@@ -28,76 +28,85 @@ namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern); #define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \ #define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(quant_dequant_op_inscale); \ GET_IR_NODE(quant_dequant_op_inscale); \
GET_IR_NODE(quant_dequant_op); \ GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_outscale); \ GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(quant_dequant_op_out); \ GET_IR_NODE(quant_dequant_op_out);
GET_IR_NODE(any_op2);
void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const { void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_op_pattern"; const std::string pattern_name = "delete_quantdequant_op_pattern";
FusePassBase::Init(pattern_name, graph); FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd; GraphPatternDetector gpd;
std::string quantdequant_types =
"fake_quantize_dequantize_moving_average_abs_max";
auto* input_node = gpd.mutable_pattern()
->NewNode("input_node")
->assert_is_op_input(quantdequant_types, "X")
->AsInput();
patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(), patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
pattern_name); pattern_name);
pattern(); pattern(input_node, quantdequant_types);
auto* scope = param_scope(); auto* scope = param_scope();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) { Graph* g) {
PADDLE_ENFORCE_EQ(
subgraph.count(input_node), true,
platform::errors::NotFound(
"Input act node(%s) not found in QuantDequantFuse pass.",
input_node->name()));
Node* input = subgraph.at(input_node);
GET_NODES; GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2); int bit_length =
std::string any_op_out_name = any_op_out->Var()->Name(); BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name(); int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
std::string input_scale_var_name = std::string input_scale_var_name =
quant_dequant_op->Op()->Input("InScale").front(); quant_dequant_op->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantOpPass should not be null."));
const LoDTensor& input_scale_tensor = const LoDTensor& input_scale_tensor =
scope->GetVar(input_scale_var_name)->Get<LoDTensor>(); scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>(); const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / 127.; float input_scale = input_scale_data[0] / range;
auto* any_op2_desc = any_op2->Op();
// auto input_args_names = any_op2_desc->InputArgumentNames();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
CHECK(arg_name.size() > 0) << "can not find the input "
<< quant_dequant_op_out_name;
any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr(arg_name + "_scale", input_scale);
// modify the any_op2's inputs // Set input scale in attr, and relink nodes
for (auto& name_m : var_map) { std::string input_name = input->Var()->Name();
if (std::find(name_m.second.begin(), name_m.second.end(), std::string quant_dequant_output_name = quant_dequant_op_out->Var()->Name();
quant_dequant_op_out_name) != name_m.second.end()) { auto outlinks = quant_dequant_op_out->outputs;
std::vector<std::string> new_inputs; for (auto* quantized_node : outlinks) {
for (auto& i_n : name_m.second) { auto op_desc = quantized_node->Op();
if (i_n != quant_dequant_op_out_name) { std::string quantized_op_type = op_desc->Type();
new_inputs.push_back(i_n); if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
} quantized_op_type == "matmul_v2") {
} op_desc->SetAttr("X_scale", input_scale);
new_inputs.push_back(any_op_out_name); } else {
any_op2_desc->SetInput(name_m.first, new_inputs); op_desc->SetAttr("Input_scale", input_scale);
any_op2_desc->Flush();
} }
op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(quant_dequant_output_name, input_name);
op_desc->Flush();
IR_NODE_LINK_TO(input, quantized_node);
} }
any_op2_desc->Flush();
// Delete the unneeded nodes. // Delete the unneeded nodes.
GraphSafeRemoveNodes(graph, GraphSafeRemoveNodes(graph,
{quant_dequant_op, quant_dequant_op_out, {quant_dequant_op_inscale, quant_dequant_op,
quant_dequant_op_inscale, quant_dequant_op_outscale}); quant_dequant_op_outscale, quant_dequant_op_out});
found_count++;
}; };
gpd(graph, handler); gpd(graph, handler);
AddStatis(found_count);
} }
} // namespace ir } // namespace ir
......
...@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) { ...@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out->LinksFrom({reshape2_op}); reshape2_out->LinksFrom({reshape2_op});
} }
void patterns::DeleteQuantDequantOpPattern::operator()() { void patterns::DeleteQuantDequantOpPattern::operator()(
auto any_op_out = PDNode *input_node, const std::string &quantdequant_types) {
pattern->NewNode(any_op_out_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "X")
->AsInput();
auto quant_dequant_op_inscale = auto quant_dequant_op_inscale =
pattern->NewNode(quant_dequant_op_inscale_repr()) pattern->NewNode(quant_dequant_op_inscale_repr())
->assert_is_op_input( ->assert_is_op_input(quantdequant_types, "InScale")
"fake_quantize_dequantize_moving_average_abs_max", "InScale")
->AsInput(); ->AsInput();
auto quant_dequant_op = auto quant_dequant_op = pattern->NewNode(quant_dequant_op_repr())
pattern->NewNode(quant_dequant_op_repr()) ->assert_is_op(quantdequant_types);
->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");
auto quant_dequant_out = auto quant_dequant_op_out =
pattern->NewNode(quant_dequant_op_out_repr()) pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_op_output( ->assert_is_op_output(quantdequant_types, "Out")
"fake_quantize_dequantize_moving_average_abs_max", "Out") ->AsOutput();
->AsIntermediate();
auto quant_dequant_op_outscale = auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr()) pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_op_output( ->assert_is_op_output(quantdequant_types, "OutScale")
"fake_quantize_dequantize_moving_average_abs_max", "OutScale")
->AsOutput(); ->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale}); quant_dequant_op->LinksFrom({quant_dequant_op_inscale, input_node});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op}); quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op}); quant_dequant_op_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
} }
void patterns::DeleteQuantDequantFilterOpPattern::operator()() { void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
......
...@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase { ...@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope) DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {} : PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {}
void operator()(); void operator()(PDNode* input_node, const std::string& quantdequant_types);
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(quant_dequant_op_inscale); PATTERN_DECL_NODE(quant_dequant_op_inscale);
PATTERN_DECL_NODE(quant_dequant_op); PATTERN_DECL_NODE(quant_dequant_op);
PATTERN_DECL_NODE(quant_dequant_op_outscale); PATTERN_DECL_NODE(quant_dequant_op_outscale);
PATTERN_DECL_NODE(quant_dequant_op_out); PATTERN_DECL_NODE(quant_dequant_op_out);
PATTERN_DECL_NODE(any_op2);
}; };
struct DeleteQuantDequantFilterOpPattern : public PatternBase { struct DeleteQuantDequantFilterOpPattern : public PatternBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册