未验证 提交 5cda6b2b 编写于 作者: W Wangzheee 提交者: GitHub

fix: delete_quant_dequant_filter_op_pass, delete_quant_dequant_op_pass (#35879)

上级 1238115e
......@@ -92,7 +92,6 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
int range = ((1 << (bit_length - 1)) - 1);
std::vector<float> weight_scale;
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
......@@ -106,43 +105,52 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
"can not find the input %s.",
quant_dequant_op_out_name));
any_op2_desc->SetAttr("enable_int8", true);
// any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length);
// modify the any_op2's inputs
any_op2_desc->Flush();
auto dequant_type = quant_dequant_op->Op()->Type();
auto quantized_op_type = any_op2_desc->Type();
// get weight tensor
auto* weight_tensor =
scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
// Get weight scale
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
auto scales_name = quant_dequant_op->Op()->Output("OutScale");
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
// To Do @Wangzheee: use "OutScale" to quantdequant
/*auto scales_name = quant_dequant_op->Op()->Output("OutScale");
PADDLE_ENFORCE_EQ(scales_name.size(), 1,
platform::errors::InvalidArgument(
"Scales size in channel-wise quant dequantize op "
"should be 1, got %d.",
scales_name.size()));
const LoDTensor& channel_scale_tensor =
scope->GetVar(scales_name[0])->Get<LoDTensor>();
scope->FindVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE(
paddle::platform::is_cpu_place(channel_scale_tensor.place()),
platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU."));
// compute the channel wise abs max of the weight tensor
int quant_axis =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("quant_axis"));
PADDLE_ENFORCE_EQ(quant_axis == 0 || quant_axis == 1, true,
platform::errors::InvalidArgument(
"'quant_axis' should be 0 or 1, but "
"the received is %d",
quant_axis));
const float* channel_scale_data = channel_scale_tensor.data<float>();
for (int i = 0; i < channel_scale_tensor.numel(); i++) {
weight_scale.push_back(channel_scale_data[i] );
}*/
// Implement channel_wise_quantize_dequantize_abs_max quantization
// algorithm
const int64_t channel = w_dims[quant_axis];
weight_scale.resize(channel, 0);
if (quant_axis == 0) {
......@@ -171,11 +179,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE(weight_scale[i], 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero."));
weight_scale[i] = range / weight_scale[i];
weight_scale[i] = weight_scale[i] / range;
}
} else {
auto scale_name = quant_dequant_op_outscale->Name();
// compute the abs max of the weight tensor
// Implement quantize_dequantize_abs_max quantization algorithm
float abs_max_weight = 0.;
for (int j = 0; j < weight_tensor->numel(); j++) {
abs_max_weight =
......@@ -184,113 +191,10 @@ void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NE(abs_max_weight, 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
weight_scale.push_back((range * range) / abs_max_weight / range);
weight_scale.push_back(abs_max_weight / range);
}
nodes2rm.insert(quant_dequant_op_outscale);
// perform quantize dequantize operations
// If quantized op is not channel wise, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (dequant_type == "fake_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), 1,
platform::errors::InvalidArgument(
"%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
quantized_op_type, weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[0];
}
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "fc") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul op weight dequantized by "
"[fake_channel_wise_quantize_dequantize_abs_max] requires "
"weight scale "
"size = 2nd dim of mul's weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j % w_dims[1]], 0,
platform::errors::InvalidArgument(
"fc op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[0]), weight_scale.size()));
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j / inner_size], 0,
platform::errors::InvalidArgument(
"conv2d op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j / inner_size];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j / inner_size];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d_transpose") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
int inner_size = w_dims[2] * w_dims[3];
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
platform::errors::InvalidArgument(
"conv2d_transpose op weight scale should be "
"nonzero, but get zero"));
quantized_weight_data[j] = quantized_weight_data[j] *
weight_scale[(j / inner_size) % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /=
weight_scale[(j / inner_size) % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
nodes2rm.insert(quant_dequant_op_out);
// link weight in quant_dequant_op_x to any_op2
......
......@@ -28,76 +28,85 @@ namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(any_op_out); \
GET_IR_NODE(quant_dequant_op_inscale); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(quant_dequant_op_out); \
GET_IR_NODE(any_op2);
GET_IR_NODE(quant_dequant_op_out);
void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
std::string quantdequant_types =
"fake_quantize_dequantize_moving_average_abs_max";
auto* input_node = gpd.mutable_pattern()
->NewNode("input_node")
->assert_is_op_input(quantdequant_types, "X")
->AsInput();
patterns::DeleteQuantDequantOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
pattern(input_node, quantdequant_types);
auto* scope = param_scope();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
PADDLE_ENFORCE_EQ(
subgraph.count(input_node), true,
platform::errors::NotFound(
"Input act node(%s) not found in QuantDequantFuse pass.",
input_node->name()));
Node* input = subgraph.at(input_node);
GET_NODES;
IR_NODE_LINK_TO(any_op_out, any_op2);
std::string any_op_out_name = any_op_out->Var()->Name();
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
int bit_length =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
std::string input_scale_var_name =
quant_dequant_op->Op()->Input("InScale").front();
PADDLE_ENFORCE_NOT_NULL(
scope, platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantOpPass should not be null."));
const LoDTensor& input_scale_tensor =
scope->GetVar(input_scale_var_name)->Get<LoDTensor>();
scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / 127.;
auto* any_op2_desc = any_op2->Op();
// auto input_args_names = any_op2_desc->InputArgumentNames();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
}
}
CHECK(arg_name.size() > 0) << "can not find the input "
<< quant_dequant_op_out_name;
any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr(arg_name + "_scale", input_scale);
float input_scale = input_scale_data[0] / range;
// modify the any_op2's inputs
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
std::vector<std::string> new_inputs;
for (auto& i_n : name_m.second) {
if (i_n != quant_dequant_op_out_name) {
new_inputs.push_back(i_n);
}
}
new_inputs.push_back(any_op_out_name);
any_op2_desc->SetInput(name_m.first, new_inputs);
any_op2_desc->Flush();
// Set input scale in attr, and relink nodes
std::string input_name = input->Var()->Name();
std::string quant_dequant_output_name = quant_dequant_op_out->Var()->Name();
auto outlinks = quant_dequant_op_out->outputs;
for (auto* quantized_node : outlinks) {
auto op_desc = quantized_node->Op();
std::string quantized_op_type = op_desc->Type();
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2") {
op_desc->SetAttr("X_scale", input_scale);
} else {
op_desc->SetAttr("Input_scale", input_scale);
}
op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(quant_dequant_output_name, input_name);
op_desc->Flush();
IR_NODE_LINK_TO(input, quantized_node);
}
any_op2_desc->Flush();
// Delete the unneeded nodes.
GraphSafeRemoveNodes(graph,
{quant_dequant_op, quant_dequant_op_out,
quant_dequant_op_inscale, quant_dequant_op_outscale});
{quant_dequant_op_inscale, quant_dequant_op,
quant_dequant_op_outscale, quant_dequant_op_out});
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
......
......@@ -2547,39 +2547,28 @@ void patterns::ShuffleChannelPattern::operator()(PDNode *reshape1_in) {
reshape2_out->LinksFrom({reshape2_op});
}
void patterns::DeleteQuantDequantOpPattern::operator()() {
auto any_op_out =
pattern->NewNode(any_op_out_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "X")
->AsInput();
void patterns::DeleteQuantDequantOpPattern::operator()(
PDNode *input_node, const std::string &quantdequant_types) {
auto quant_dequant_op_inscale =
pattern->NewNode(quant_dequant_op_inscale_repr())
->assert_is_op_input(
"fake_quantize_dequantize_moving_average_abs_max", "InScale")
->assert_is_op_input(quantdequant_types, "InScale")
->AsInput();
auto quant_dequant_op =
pattern->NewNode(quant_dequant_op_repr())
->assert_is_op("fake_quantize_dequantize_moving_average_abs_max");
auto quant_dequant_op = pattern->NewNode(quant_dequant_op_repr())
->assert_is_op(quantdequant_types);
auto quant_dequant_out =
auto quant_dequant_op_out =
pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "Out")
->AsIntermediate();
->assert_is_op_output(quantdequant_types, "Out")
->AsOutput();
auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_op_output(
"fake_quantize_dequantize_moving_average_abs_max", "OutScale")
->assert_is_op_output(quantdequant_types, "OutScale")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quant_dequant_op->LinksFrom({any_op_out, quant_dequant_op_inscale});
quant_dequant_op->LinksFrom({quant_dequant_op_inscale, input_node});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
quant_dequant_op_out->LinksFrom({quant_dequant_op});
}
void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
......
......@@ -1481,14 +1481,12 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
DeleteQuantDequantOpPattern(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "delete_quantdequant_op_pattern") {}
void operator()();
void operator()(PDNode* input_node, const std::string& quantdequant_types);
PATTERN_DECL_NODE(any_op_out);
PATTERN_DECL_NODE(quant_dequant_op_inscale);
PATTERN_DECL_NODE(quant_dequant_op);
PATTERN_DECL_NODE(quant_dequant_op_outscale);
PATTERN_DECL_NODE(quant_dequant_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantFilterOpPattern : public PatternBase {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册