Unverified  Commit 95332bef  authored by RichardWooSJTU, committed by GitHub

rewrite delete_weight_dequant_linear_op_encoder/decoder pass (#48650)

* rewrite delete_weight_dequant_linear_op_encoder/decoder pass
Parent a14ae84b
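In brief: the encoder/decoder variants of the pass are removed in favor of a single delete_weight_dequant_linear_op_pass that walks the graph directly, while the old pattern-based implementation (which rewrote the int8 weights to float in place) moves to a new trt_delete_weight_dequant_linear_op_pass for the TensorRT pipeline. A rough sketch of what the rewritten pass does to a matched subgraph (names are illustrative):

    // Before: weight(int8), scale --> dequantize_linear --> weight(fp32) --> matmul_v2
    // After:  weight(int8) -----------------------------------------------> matmul_v2
    //         (matmul_v2 gains the attribute "weight_scale"; the scale var, the
    //          dequantize_linear op, and its output var are removed from the graph)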
@@ -95,9 +95,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
 pass_library(shuffle_channel_detect_pass inference)
 pass_library(delete_quant_dequant_op_pass inference)
 pass_library(delete_quant_dequant_filter_op_pass inference)
+pass_library(trt_delete_weight_dequant_linear_op_pass inference)
 pass_library(delete_weight_dequant_linear_op_pass inference)
-pass_library(delete_weight_dequant_linear_op_encoder_pass inference)
-pass_library(delete_weight_dequant_linear_op_decoder_pass inference)
 pass_library(delete_quant_dequant_linear_op_pass inference)
 pass_library(delete_dropout_op_pass inference)
 pass_library(delete_c_identity_op_pass inference)
@@ -359,6 +358,10 @@ cc_test(
   test_delete_dropout_pass_cc
   SRCS delete_dropout_op_pass_test.cc
   DEPS delete_dropout_op_pass)
+cc_test(
+  test_delete_dequant_weight_linear_op_pass
+  SRCS delete_weight_dequant_linear_op_pass_tester.cc
+  DEPS delete_weight_dequant_linear_op_pass)
 if(WITH_GPU OR WITH_ROCM)
   cc_test(
     test_embedding_eltwise_layernorm_fuse_pass
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_decoder_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightDequantLinearOpDecoderPass::
DeleteWeightDequantLinearOpDecoderPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End()
.AddAttr("round_type")
.IsOptional()
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete dequantize_linear_op and record the weight scale on the consumer op
// (the decoder variant does not dequantize the weight itself)
void DeleteWeightDequantLinearOpDecoderPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_dequant_linear_op_decoder_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightDequantLinearOpDecoderPass "
"should not be null."));
// Create pattern
patterns::DeleteWeightDequantLinearOpDecoderPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
bool is_int8 = false;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
is_int8 = true;
std::unordered_set<const Node*> nodes2rm = {};
auto* any_op2_desc = any_op2->Op();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<phi::DenseTensor>();
auto weight_scale_nums = weight_scale_tensor->numel();
if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT32) {
float* weight_scale_data = weight_scale_tensor->data<float>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i]);
}
} else if (weight_scale_tensor->dtype() ==
paddle::experimental::DataType::FLOAT16) {
phi::dtype::float16* weight_scale_data =
weight_scale_tensor->data<phi::dtype::float16>();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(static_cast<float>(weight_scale_data[i]));
}
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"%d is not supported.", weight_scale_tensor->dtype()));
}
int quant_axis = PADDLE_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums,
1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
// Add attr to anyop 2
any_op2_desc->SetAttr("weight_scale", weight_scale[0]);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Delete Weight Dequant Linear Op Encoder Pass is not supported for "
"per-channel quantization"));
}
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
if (is_int8) {
auto& enable_int8 = graph->Get<bool>("enable_int8");
enable_int8 = true;
}
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_decoder_pass,
paddle::framework::ir::DeleteWeightDequantLinearOpDecoderPass);
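For orientation, the subgraph that the decoder pattern (bound via the GET_NODES macro above) matches is:

    // weight_dequantize_linear_op_x -----.
    // weight_dequantize_linear_op_scale -+-> weight_dequantize_linear_op
    //                                          |
    //                                          v
    //                        weight_dequantize_linear_op_out --> any_op2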
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightDequantLinearOpEncoderPass : public FusePassBase {
public:
DeleteWeightDequantLinearOpEncoderPass();
virtual ~DeleteWeightDequantLinearOpEncoderPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
-
-#include <algorithm>
-#include <memory>
-#include <string>
-#include <unordered_set>
-#include <vector>
-
-namespace paddle {
-namespace framework {
-namespace ir {
-
-#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
-#define GET_NODES                                 \
-  GET_IR_NODE(weight_dequantize_linear_op_x);     \
-  GET_IR_NODE(weight_dequantize_linear_op_scale); \
-  GET_IR_NODE(weight_dequantize_linear_op);       \
-  GET_IR_NODE(weight_dequantize_linear_op_out);   \
-  GET_IR_NODE(any_op2);
-
-DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
-  AddOpCompat(OpCompat("quantize_linear"))
-      .AddInput("X")
-      .IsTensor()
-      .End()
-      .AddInput("Scale")
-      .IsTensor()
-      .End()
-      .AddInput("ZeroPoint")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Y")
-      .IsTensor()
-      .End()
-      .AddAttr("bit_length")
-      .IsType<int>()
-      .End()
-      .AddAttr("quant_axis")
-      .IsType<int>()
-      .End()
-      .AddAttr("round_type")
-      .IsOptional()
-      .IsType<int>()
-      .End();
-  AddOpCompat(OpCompat("dequantize_linear"))
-      .AddInput("X")
-      .IsTensor()
-      .End()
-      .AddInput("Scale")
-      .IsTensor()
-      .End()
-      .AddInput("ZeroPoint")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Y")
-      .IsTensor()
-      .End()
-      .AddAttr("bit_length")
-      .IsType<int>()
-      .End()
-      .AddAttr("quant_axis")
-      .IsType<int>()
-      .End()
-      .AddAttr("round_type")
-      .IsOptional()
-      .IsType<int>()
-      .End();
-  AddOpCompat(OpCompat("conv2d"))
-      .AddInput("Input")
-      .IsTensor()
-      .End()
-      .AddInput("Filter")
-      .IsTensor()
-      .End()
-      .AddInput("Bias")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddInput("ResidualData")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Output")
-      .IsTensor()
-      .End()
-      .AddAttr("strides")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("paddings")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("padding_algorithm")
-      .IsOptional()
-      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
-      .End()
-      .AddAttr("groups")
-      .IsNumGE(1)
-      .End()
-      .AddAttr("dilations")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
-      .End();
-  AddOpCompat(OpCompat("depthwise_conv2d"))
-      .AddInput("Input")
-      .IsTensor()
-      .End()
-      .AddInput("Filter")
-      .IsTensor()
-      .End()
-      .AddInput("Bias")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddInput("ResidualData")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Output")
-      .IsTensor()
-      .End()
-      .AddAttr("strides")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("paddings")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("padding_algorithm")
-      .IsOptional()
-      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
-      .End()
-      .AddAttr("groups")
-      .IsNumGE(1)
-      .End()
-      .AddAttr("dilations")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
-      .End();
-  AddOpCompat(OpCompat("mul"))
-      .AddInput("X")
-      .IsTensor()
-      .End()
-      .AddInput("Y")
-      .IsTensor()
-      .End()
-      .AddOutput("Out")
-      .IsTensor()
-      .End()
-      .AddAttr("x_num_col_dims")
-      .IsNumGE(1)
-      .End()
-      .AddAttr("y_num_col_dims")
-      .IsNumEQ(1)
-      .End();
-  AddOpCompat(OpCompat("matmul_v2"))
-      .AddInput("X")
-      .IsTensor()
-      .End()
-      .AddInput("Y")
-      .IsTensor()
-      .End()
-      .AddOutput("Out")
-      .IsTensor()
-      .End()
-      .AddAttr("trans_x")
-      .IsBoolEQ(false)
-      .End()
-      .AddAttr("trans_y")
-      .IsBoolEQ(false)
-      .End();
-  AddOpCompat(OpCompat("matmul"))
-      .AddInput("X")
-      .IsTensor()
-      .End()
-      .AddInput("Y")
-      .IsTensor()
-      .End()
-      .AddOutput("Out")
-      .IsTensor()
-      .End()
-      .AddAttr("alpha")
-      .IsNumGE(0.99f)
-      .IsNumLE(1.01f)
-      .End()
-      .AddAttr("transpose_X")
-      .IsBoolEQ(false)
-      .End()
-      .AddAttr("transpose_Y")
-      .IsBoolEQ(false)
-      .End();
-  AddOpCompat(OpCompat("fc"))
-      .AddInput("Input")
-      .IsTensor()
-      .End()
-      .AddInput("W")
-      .IsTensor()
-      .End()
-      .AddInput("Bias")
-      .IsTensor()
-      .End()
-      .AddOutput("Out")
-      .IsTensor()
-      .End()
-      .AddAttr("in_num_col_dims")
-      .IsNumGE(1)
-      .End()
-      .AddAttr("activation_type")
-      .IsStringIn({"relu", ""})
-      .End();
-  AddOpCompat(OpCompat("conv2d_transpose"))
-      .AddInput("Input")
-      .IsTensor()
-      .End()
-      .AddInput("Filter")
-      .IsTensor()
-      .End()
-      .AddInput("Bias")
-      .IsTensor()
-      .IsOptional()
-      .End()
-      .AddOutput("Output")
-      .IsTensor()
-      .End()
-      .AddAttr("output_padding")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("output_size")
-      .IsType<std::vector<int>>()
-      .IsOptional()
-      .End()
-      .AddAttr("groups")
-      .IsNumGE(1)
-      .End()
-      .AddAttr("dilations")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("strides")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("paddings")
-      .IsType<std::vector<int>>()
-      .End()
-      .AddAttr("padding_algorithm")
-      .IsOptional()
-      .IsStringIn({"EXPLICIT", "SAME", "VALID"})
-      .End()
-      .AddAttr("data_format")
-      .IsStringIn({"NCHW", "NHWC", "AnyLayout"})
-      .End();
-}
-
-// Delete dequantize_linear_op, then dequantize weight
-void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
-  const std::string pattern_name =
-      "delete_weight_quantdequant_linear_op_pattern";
-  FusePassBase::Init(pattern_name, graph);
-
-  GraphPatternDetector gpd;
-  auto* scope = param_scope();
-  PADDLE_ENFORCE_NOT_NULL(
-      scope,
-      platform::errors::InvalidArgument(
-          "Scope in DeleteWeightQuantDequantLinearOpPass should not be null."));
-
-  // Create pattern
-  patterns::DeleteWeightQuantDequantLinearOpPattern pattern(
-      gpd.mutable_pattern(), pattern_name);
-  pattern();
-  int found_count = 0;
-  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
-                     Graph* g) {
-    GET_NODES;
-    /*
-    if (!IsCompat(subgraph, g)) {
-      LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
-                      "compat check failed.";
-      return;
-    }
-    */
-    std::unordered_set<const Node*> nodes2rm = {};
-    int bit_length = PADDLE_GET_CONST(
-        int, weight_dequantize_linear_op->Op()->GetAttr("bit_length"));
-    int range = ((1 << (bit_length - 1)) - 1);
-
-    auto* any_op2_desc = any_op2->Op();
-
-    // get weight tensor
-    auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name())
-                              ->GetMutable<phi::DenseTensor>();
-    int8_t* quantized_weight_data =
-        weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
-    auto w_dims = weight_tensor->dims();
-
-    // Get weight scale
-    std::vector<float> weight_scale;
-    auto* weight_scale_tensor =
-        scope->GetVar(weight_dequantize_linear_op_scale->Name())
-            ->GetMutable<phi::DenseTensor>();
-    float* weight_scale_data =
-        weight_scale_tensor->mutable_data<float>(platform::CPUPlace());
-    auto weight_scale_nums = weight_scale_tensor->numel();
-    for (int i = 0; i < weight_scale_nums; i++) {
-      weight_scale.push_back(weight_scale_data[i] / range);
-    }
-
-    // dequant weight
-    std::vector<float> weight_data_tmp;
-    weight_data_tmp.reserve(weight_tensor->numel());
-    int quant_axis = PADDLE_GET_CONST(
-        int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
-    if (quant_axis == -1) {  // per_layer quant_dequant: all OP
-      PADDLE_ENFORCE_EQ(weight_scale_nums,
-                        1,
-                        platform::errors::InvalidArgument(
-                            "When quant_axis == -1 means use per_layer "
-                            "quant_dequant, weight_scale'number should be 1."));
-      // float(weight) * scale
-      for (int i = 0; i < weight_tensor->numel(); i++) {
-        weight_data_tmp[i] =
-            static_cast<float>(quantized_weight_data[i]) * weight_scale[0];
-      }
-    } else if (quant_axis == 0) {  // per_channel quant_dequant: conv2d,
-                                   // depthwise_conv2d, conv2d_fusion
-      PADDLE_ENFORCE_EQ(
-          weight_scale_nums,
-          w_dims[quant_axis],
-          platform::errors::InvalidArgument(
-              "When quant_axis == 0 means use per_channel quant_dequant, "
-              "weight_scale'numbers should be equal channels."));
-      PADDLE_ENFORCE_EQ(w_dims.size(),
-                        4,
-                        platform::errors::InvalidArgument(
-                            "When quant_axis == 0 means use per_channel "
-                            "quant_dequant, (conv2d, depthwise_conv2d, "
-                            "conv2d_fusion)'s weight dims should be 4."));
-      for (int i = 0; i < weight_tensor->numel(); i++) {
-        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
-        weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
-                             weight_scale[i / inner_size];
-      }
-    } else if (quant_axis == 1) {
-      PADDLE_ENFORCE_EQ(
-          weight_scale_nums,
-          w_dims[quant_axis],
-          platform::errors::InvalidArgument(
-              "When quant_axis == 1 means use per_channel quant_dequant, "
-              "weight_scale'numbers should be equal channels."));
-      if (w_dims.size() == 4) {  // conv2d_transpose
-        std::string quantized_op_type = any_op2->Op()->Type();
-        PADDLE_ENFORCE_EQ(
-            quantized_op_type,
-            "conv2d_transpose",
-            platform::errors::InvalidArgument(
-                "When quant_axis == 1 means use per_channel quant_dequant, "
-                "only conv2d_transpose weight dims equal 4."));
-        for (int i = 0; i < weight_tensor->numel(); i++) {
-          int inner_size = w_dims[2] * w_dims[3];
-          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
-                               weight_scale[(i / inner_size) % w_dims[1]];
-        }
-      } else if (w_dims.size() == 2) {
-        for (int i = 0; i < weight_tensor->numel(); i++) {
-          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
-                               weight_scale[i % w_dims[1]];
-        }
-      } else {
-        PADDLE_THROW(platform::errors::InvalidArgument(
-            "When quant_axis == 1 , weight dims should be 2 or 4, please "
-            "check your model "));
-      }
-    } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "quant_axis should be -1 or 0 or 1, please check your model "
-          "OP'attribute "));
-    }
-    weight_tensor->clear();  // clear int weight
-    weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
-    float* new_quantized_weight_data =
-        weight_tensor->mutable_data<float>(platform::CPUPlace());
-    memcpy(new_quantized_weight_data,
-           weight_data_tmp.data(),
-           weight_tensor->numel() * sizeof(float));
-
-    nodes2rm.insert(weight_dequantize_linear_op_scale);
-    nodes2rm.insert(weight_dequantize_linear_op);
-    nodes2rm.insert(weight_dequantize_linear_op_out);
-
-    // relink weight to any_op2
-    any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
-                              weight_dequantize_linear_op_x->Var()->Name());
-    any_op2_desc->Flush();
-    IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
-    GraphSafeRemoveNodes(graph, nodes2rm);
-    found_count++;
-  };
-  gpd(graph, handler);
-  AddStatis(found_count);
-}
-
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
-
-REGISTER_PASS(delete_weight_dequant_linear_op_pass,
-              paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass);
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+
+#include "glog/logging.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+void DeleteWeightDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
+  std::unordered_set<std::string> op_list = {"matmul_v2",
+                                             "matmul",
+                                             "mul",
+                                             "fc",
+                                             "depthwise_conv2d",
+                                             "conv2d",
+                                             "conv2d_transpose"};
+  PADDLE_ENFORCE_EQ(graph->Has(kParamScopeAttr),
+                    true,
+                    platform::errors::InvalidArgument(
+                        "Graph must have kParamScopeAttr attribute."));
+
+  auto& scope = graph->Get<framework::Scope>(kParamScopeAttr);
+  bool is_int8 = false;
+
+  std::unordered_set<const Node*> nodes2rm;
+
+  for (const Node* n : graph->Nodes()) {
+    if (n->IsOp()) {
+      auto* op = n->Op();
+      if (op->Type() == "dequantize_linear") {
+        // Initialized to nullptr so the null checks below are meaningful.
+        Node* weight_var_node = nullptr;
+        Node* dequantized_weight_var_node = nullptr;
+        Node* scale_var_node = nullptr;
+        Node* calcu_op_node = nullptr;
+        Node* while_op_node = nullptr;
+        // 1. Check that this dequantize_linear dequantizes a weight, and find
+        // weight_var_node/scale_var_node
+        for (auto* input_node : n->inputs) {
+          if (input_node->IsVar() && input_node->Var()->Persistable()) {
+            is_int8 = true;
+            if (input_node->Var()->Name() == op->Input("X")[0]) {
+              weight_var_node = input_node;
+            } else if (input_node->Var()->Name() == op->Input("Scale")[0]) {
+              scale_var_node = input_node;
+            }
+          } else {
+            return;
+          }
+        }
+        // 2. Find next_op_node
+        // For while op: delete its input which is related to the dequantized
+        // weight; for calculation ops: set the weight scale as an attribute
+        for (auto* output_node : n->outputs) {
+          if (output_node->IsVar() &&
+              output_node->Var()->Name() == op->Output("Y")[0]) {
+            dequantized_weight_var_node = output_node;
+            for (auto* next_op_node : output_node->outputs) {
+              if (next_op_node->IsOp()) {
+                if (next_op_node->Op()->Type() == "while") {
+                  while_op_node = next_op_node;
+                  auto while_op_desc = while_op_node->Op();
+                  auto while_Xs = while_op_desc->Input("X");
+                  while_Xs.erase(std::remove(std::begin(while_Xs),
+                                             std::end(while_Xs),
+                                             output_node->Var()->Name()),
+                                 std::end(while_Xs));
+                  while_op_node->Op()->SetInput("X", while_Xs);
+                } else if (op_list.count(next_op_node->Op()->Type()) != 0) {
+                  calcu_op_node = next_op_node;
+                  auto* calcu_op_desc = calcu_op_node->Op();
+
+                  // Get weight scale
+                  std::vector<float> weight_scale;
+                  auto* weight_scale_tensor =
+                      scope.GetVar(scale_var_node->Name())
+                          ->GetMutable<phi::DenseTensor>();
+                  auto weight_scale_nums = weight_scale_tensor->numel();
+
+                  if (weight_scale_tensor->dtype() ==
+                      paddle::experimental::DataType::FLOAT32) {
+                    float* weight_scale_data =
+                        weight_scale_tensor->data<float>();
+                    for (int i = 0; i < weight_scale_nums; i++) {
+                      weight_scale.push_back(weight_scale_data[i]);
+                    }
+                  } else if (weight_scale_tensor->dtype() ==
+                             paddle::experimental::DataType::FLOAT16) {
+                    phi::dtype::float16* weight_scale_data =
+                        weight_scale_tensor->data<phi::dtype::float16>();
+                    for (int i = 0; i < weight_scale_nums; i++) {
+                      weight_scale.push_back(
+                          static_cast<float>(weight_scale_data[i]));
+                    }
+                  } else {
+                    PADDLE_THROW(platform::errors::Unimplemented(
+                        "The dtype of quantization scale must be FP32/16, "
+                        "but received %d, which is not supported.",
+                        weight_scale_tensor->dtype()));
+                  }
+
+                  int quant_axis =
+                      PADDLE_GET_CONST(int, op->GetAttr("quant_axis"));
+                  if (quant_axis == -1) {  // per_layer quant_dequant: all OP
+                    PADDLE_ENFORCE_EQ(
+                        weight_scale_nums,
+                        1,
+                        platform::errors::InvalidArgument(
+                            "When quant_axis == -1, it means using per_layer "
+                            "dequantization. In this situation, the number of "
+                            "weight_scale should be 1, but received %d.",
+                            weight_scale_nums));
+                    calcu_op_desc->SetAttr("weight_scale", weight_scale[0]);
+                  } else {
+                    PADDLE_THROW(platform::errors::Unimplemented(
+                        "Delete Weight Dequant Linear Op Pass is not "
+                        "supported for per-channel quantization"));
+                  }
+                  calcu_op_desc->RenameInput(
+                      dequantized_weight_var_node->Var()->Name(),
+                      weight_var_node->Var()->Name());
+                }
+              }
+            }
+          }
+        }
+        // 3. Delete dequant op
+        if (weight_var_node != nullptr && calcu_op_node != nullptr) {
+          IR_NODE_LINK_TO(weight_var_node, calcu_op_node);
+        }
+        std::vector<const Node*> nodes2rm_local{
+            dequantized_weight_var_node, scale_var_node, n};
+        for (auto* node2rm : nodes2rm_local) {
+          if (node2rm) {
+            nodes2rm.insert(node2rm);
+          }
+        }
+      }
+    }
+  }
+
+  GraphSafeRemoveNodes(graph, nodes2rm);
+  graph->Set("enable_int8", new bool(is_int8));
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_weight_dequant_linear_op_pass,
+              paddle::framework::ir::DeleteWeightDequantLinearOpPass);
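One detail worth calling out in the rewritten ApplyImpl above is the erase-remove idiom used to drop the dequantized weight from a while op's "X" input list. A minimal standalone sketch of just that step (names are illustrative):

    #include <algorithm>
    #include <string>
    #include <vector>

    void RemoveWhileInput(std::vector<std::string>* while_xs,
                          const std::string& dequantized_name) {
      // std::remove shifts the surviving names forward; erase trims the tail.
      while_xs->erase(
          std::remove(while_xs->begin(), while_xs->end(), dequantized_name),
          while_xs->end());
    }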
-// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
+/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */

 #pragma once

-#include <vector>
-
-#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-class DeleteWeightQuantDequantLinearOpPass : public FusePassBase {
- public:
-  DeleteWeightQuantDequantLinearOpPass();
-  virtual ~DeleteWeightQuantDequantLinearOpPass() {}
+class Graph;
+
+class DeleteWeightDequantLinearOpPass : public Pass {
  protected:
   void ApplyImpl(ir::Graph* graph) const override;
 };
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
template <typename T>
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
auto* dev_ctx = static_cast<phi::CPUContext*>(
platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
dev_ctx->HostAlloc<T>(tensor, tensor->numel() * sizeof(T));
}
template <typename T>
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope<T>(param_scope, "scale", {1});
return param_scope;
}
TEST(DeleteWeightDequantLinearOpPass, basic) {
// inputs operator output
// --------------------------------------------------------------------
// (weight, scale) dequantize_linear -> dequantized_weight
// (x, dequantized_weight) matmul/fc/conv -> matmul_out
// (dequantized_weight) while -> [optional]
Layers layers;
auto* x = layers.data("x", {1, 128, 768});
auto* weight = layers.data("weight", {768, 768}, true);
auto* scale = layers.data("scale", {1}, true);
auto* zero_point = layers.data("zero_point", {1}, true);
auto* dequantized_weight =
layers.dequantize_linear(weight, scale, zero_point);
layers.matmul_v2(x, dequantized_weight);
layers.while_loop({dequantized_weight});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope<float>());
auto pass =
PassRegistry::Instance().Get("delete_weight_dequant_linear_op_pass");
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_dequant_nodes_after = GetNumOpNodes(graph, "dequantize_linear");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 3,
platform::errors::InvalidArgument(
"After pass, the number of nodes should be reduced by 3, but the "
"number before pass is %d, after pass is %d.",
num_nodes_before,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_dequant_nodes_after,
0,
platform::errors::InvalidArgument(
"After pass, the number of nodes of type "
"'dequantize_linear' should be 1, not %d.",
num_dequant_nodes_after));
}
TEST(DeleteWeightDequantLinearOpPass, basic_fp16) {
// inputs operator output
// --------------------------------------------------------------------
// (weight, scale) dequantize_linear -> dequantized_weight
// (x, dequantized_weight) matmul/fc/conv -> matmul_out
// (dequantized_weight) while -> [optional]
Layers layers;
auto* x = layers.data("x", {1, 128, 768});
auto* weight = layers.data("weight", {768, 768}, true);
auto* scale = layers.data("scale", {1}, true);
auto* zero_point = layers.data("zero_point", {1}, true);
auto* dequantized_weight =
layers.dequantize_linear(weight, scale, zero_point);
layers.matmul_v2(x, dequantized_weight);
layers.while_loop({dequantized_weight});
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope<phi::dtype::float16>());
auto pass =
PassRegistry::Instance().Get("delete_weight_dequant_linear_op_pass");
int num_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
graph.reset(pass->Apply(graph.release()));
int num_nodes_after = graph->Nodes().size();
int num_dequant_nodes_after = GetNumOpNodes(graph, "dequantize_linear");
VLOG(3) << DebugString(graph);
PADDLE_ENFORCE_EQ(
num_nodes_before,
num_nodes_after + 3,
platform::errors::InvalidArgument(
"After pass, the number of nodes should be reduced by 3, but the "
"number before pass is %d, after pass is %d.",
num_nodes_before,
num_nodes_after));
PADDLE_ENFORCE_EQ(num_dequant_nodes_after,
0,
platform::errors::InvalidArgument(
"After pass, the number of nodes of type "
"'dequantize_linear' should be 1, not %d.",
num_dequant_nodes_after));
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(delete_weight_dequant_linear_op_pass);
@@ -48,8 +48,8 @@ static const std::vector<std::string> support_subgraph_passes = {
     "multi_devices_fused_multi_transformer_decoder_fuse_qkv_pass",
     "fuse_multi_transformer_layer_pass",
     "delete_quant_dequant_linear_op_pass",
-    "delete_weight_dequant_linear_op_encoder_pass",
-    "delete_weight_dequant_linear_op_decoder_pass"};
+    "delete_weight_dequant_linear_op_pass",
+};

 Graph *Pass::Apply(Graph *graph) const {
   VLOG(10) << "start to apply pass " << Type() << " to graph";
......
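A note on this list: judging by its name and by the while-op handling added in the rewritten pass, passes registered in support_subgraph_passes are, presumably, also applied to the sub-graphs owned by control-flow ops such as while; swapping the encoder/decoder entries for the unified pass keeps decoder-style models, whose weights are consumed inside a while block, covered.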
@@ -641,6 +641,23 @@ struct Layers {
     return out;
   }

+  VarDesc* dequantize_linear(VarDesc* x,
+                             VarDesc* scale,
+                             VarDesc* zero_point,
+                             int bit_length = 8,
+                             int quant_axis = -1) {
+    VarDesc* out = lod_tensor(unique_name());
+    OpDesc* op = program_.MutableBlock(0)->AppendOp();
+    op->SetType("dequantize_linear");
+    op->SetInput("X", {x->Name()});
+    op->SetInput("Scale", {scale->Name()});
+    op->SetInput("ZeroPoint", {zero_point->Name()});
+    op->SetAttr("bit_length", bit_length);
+    op->SetAttr("quant_axis", quant_axis);
+    op->SetOutput("Y", {out->Name()});
+    return out;
+  }
+
   void backward(std::vector<VarDesc*> targets) {
     // This function is designed to simulate the structure of training program,
     // but is constructed differently as the actual program.
......
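The tester earlier in this commit drives this helper through Layers; in isolation, a minimal usage sketch mirroring that test (shapes and names are illustrative) is:

    Layers layers;
    auto* x = layers.data("x", {1, 128, 768});
    auto* w = layers.data("weight", {768, 768}, true);  // persistable weight
    auto* s = layers.data("scale", {1}, true);
    auto* zp = layers.data("zero_point", {1}, true);
    // Defaults: bit_length = 8, quant_axis = -1 (per-layer).
    auto* dq = layers.dequantize_linear(w, s, zp);
    layers.matmul_v2(x, dq);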
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_encoder_pass.h"
+#include "paddle/fluid/framework/ir/trt_delete_weight_dequant_linear_op_pass.h"

 #include <algorithm>
 #include <memory>
@@ -32,8 +32,8 @@ namespace ir {
   GET_IR_NODE(weight_dequantize_linear_op_out);   \
   GET_IR_NODE(any_op2);

-DeleteWeightDequantLinearOpEncoderPass::
-    DeleteWeightDequantLinearOpEncoderPass() {
+TrtDeleteWeightQuantDequantLinearOpPass::
+    TrtDeleteWeightQuantDequantLinearOpPass() {
   AddOpCompat(OpCompat("quantize_linear"))
       .AddInput("X")
       .IsTensor()
@@ -270,64 +270,69 @@ DeleteWeightDequantLinearOpEncoderPass::
       .End();
 }
 // Delete dequantize_linear_op, then dequantize weight
-void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
+void TrtDeleteWeightQuantDequantLinearOpPass::ApplyImpl(
+    ir::Graph* graph) const {
   const std::string pattern_name =
-      "delete_weight_dequant_linear_op_encoder_pattern";
+      "delete_weight_quantdequant_linear_op_pattern";
   FusePassBase::Init(pattern_name, graph);

   GraphPatternDetector gpd;
   auto* scope = param_scope();
-  PADDLE_ENFORCE_NOT_NULL(scope,
-                          platform::errors::InvalidArgument(
-                              "Scope in DeleteWeightDequantLinearOpEncoderPass "
-                              "should not be null."));
+  PADDLE_ENFORCE_NOT_NULL(
+      scope,
+      platform::errors::InvalidArgument(
+          "Scope in TrtDeleteWeightQuantDequantLinearOpPass should not be "
+          "null."));
   // Create pattern
-  patterns::DeleteWeightDequantLinearOpEncoderPattern pattern(
+  patterns::DeleteWeightQuantDequantLinearOpPattern pattern(
       gpd.mutable_pattern(), pattern_name);
   pattern();
   int found_count = 0;
-  bool is_int8 = false;
+  // Device context
+  auto* dev_ctx = static_cast<phi::CPUContext*>(
+      platform::DeviceContextPool::Instance().Get(platform::CPUPlace()));
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
     GET_NODES;
     /*
     if (!IsCompat(subgraph, g)) {
-      LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
+      LOG(WARNING) << "trt_delete_weight_dequant_linear_op_pass "
                       "compat check failed.";
       return;
     }
     */
-    is_int8 = true;
     std::unordered_set<const Node*> nodes2rm = {};
+    int bit_length = PADDLE_GET_CONST(
+        int, weight_dequantize_linear_op->Op()->GetAttr("bit_length"));
+    int range = ((1 << (bit_length - 1)) - 1);

     auto* any_op2_desc = any_op2->Op();

+    // get weight tensor
+    auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name())
+                              ->GetMutable<phi::DenseTensor>();
+    int8_t* quantized_weight_data = weight_tensor->data<int8_t>();
+    auto w_dims = weight_tensor->dims();
+
     // Get weight scale
     std::vector<float> weight_scale;
     auto* weight_scale_tensor =
         scope->GetVar(weight_dequantize_linear_op_scale->Name())
             ->GetMutable<phi::DenseTensor>();
-    auto weight_scale_nums = weight_scale_tensor->numel();
-    if (weight_scale_tensor->dtype() ==
-        paddle::experimental::DataType::FLOAT32) {
-      float* weight_scale_data = weight_scale_tensor->data<float>();
-      for (int i = 0; i < weight_scale_nums; i++) {
-        weight_scale.push_back(weight_scale_data[i]);
-      }
-    } else if (weight_scale_tensor->dtype() ==
-               paddle::experimental::DataType::FLOAT16) {
-      phi::dtype::float16* weight_scale_data =
-          weight_scale_tensor->data<phi::dtype::float16>();
-      for (int i = 0; i < weight_scale_nums; i++) {
-        weight_scale.push_back(static_cast<float>(weight_scale_data[i]));
-      }
-    } else {
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "%d is not supported.", weight_scale_tensor->dtype()));
-    }
+    float* weight_scale_data = weight_scale_tensor->data<float>();
+    auto weight_scale_nums = weight_scale_tensor->numel();
+    for (int i = 0; i < weight_scale_nums; i++) {
+      weight_scale.push_back(weight_scale_data[i] / range);
+    }
+
+    // dequant weight
+    std::vector<float> weight_data_tmp;
+    weight_data_tmp.reserve(weight_tensor->numel());
     int quant_axis = PADDLE_GET_CONST(
         int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
     if (quant_axis == -1) {  // per_layer quant_dequant: all OP
@@ -337,13 +342,74 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
                             "When quant_axis == -1 means use per_layer "
                             "quant_dequant, weight_scale'number should be 1."));
-      // Add attr to anyop 2
-      any_op2_desc->SetAttr("weight_scale", weight_scale[0]);
+      // float(weight) * scale
+      for (int i = 0; i < weight_tensor->numel(); i++) {
+        weight_data_tmp[i] =
+            static_cast<float>(quantized_weight_data[i]) * weight_scale[0];
+      }
+    } else if (quant_axis == 0) {  // per_channel quant_dequant: conv2d,
+                                   // depthwise_conv2d, conv2d_fusion
+      PADDLE_ENFORCE_EQ(
+          weight_scale_nums,
+          w_dims[quant_axis],
+          platform::errors::InvalidArgument(
+              "When quant_axis == 0 means use per_channel quant_dequant, "
+              "weight_scale'numbers should be equal channels."));
+      PADDLE_ENFORCE_EQ(w_dims.size(),
+                        4,
+                        platform::errors::InvalidArgument(
+                            "When quant_axis == 0 means use per_channel "
+                            "quant_dequant, (conv2d, depthwise_conv2d, "
+                            "conv2d_fusion)'s weight dims should be 4."));
+      for (int i = 0; i < weight_tensor->numel(); i++) {
+        int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
+        weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
+                             weight_scale[i / inner_size];
+      }
+    } else if (quant_axis == 1) {
+      PADDLE_ENFORCE_EQ(
+          weight_scale_nums,
+          w_dims[quant_axis],
+          platform::errors::InvalidArgument(
+              "When quant_axis == 1 means use per_channel quant_dequant, "
+              "weight_scale'numbers should be equal channels."));
+      if (w_dims.size() == 4) {  // conv2d_transpose
+        std::string quantized_op_type = any_op2->Op()->Type();
+        PADDLE_ENFORCE_EQ(
+            quantized_op_type,
+            "conv2d_transpose",
+            platform::errors::InvalidArgument(
+                "When quant_axis == 1 means use per_channel quant_dequant, "
+                "only conv2d_transpose weight dims equal 4."));
+        for (int i = 0; i < weight_tensor->numel(); i++) {
+          int inner_size = w_dims[2] * w_dims[3];
+          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
+                               weight_scale[(i / inner_size) % w_dims[1]];
+        }
+      } else if (w_dims.size() == 2) {
+        for (int i = 0; i < weight_tensor->numel(); i++) {
+          weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
+                               weight_scale[i % w_dims[1]];
+        }
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "When quant_axis == 1 , weight dims should be 2 or 4, please "
+            "check your model "));
+      }
     } else {
-      PADDLE_THROW(platform::errors::Unimplemented(
-          "Delete Weight Dequant Linear Op Encoder Pass is not supported for "
-          "per-channel quantization"));
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "quant_axis should be -1 or 0 or 1, please check your model "
+          "OP'attribute "));
     }
+    weight_tensor->clear();  // clear int weight
+    weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
+    float* new_quantized_weight_data = dev_ctx->HostAlloc<float>(
+        weight_tensor, weight_tensor->numel() * sizeof(float));
+    memcpy(new_quantized_weight_data,
+           weight_data_tmp.data(),
+           weight_tensor->numel() * sizeof(float));

     nodes2rm.insert(weight_dequantize_linear_op_scale);
     nodes2rm.insert(weight_dequantize_linear_op);
@@ -358,7 +424,6 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
     found_count++;
   };
   gpd(graph, handler);
-  graph->Set("enable_int8", new bool(is_int8));
   AddStatis(found_count);
 }
@@ -366,5 +431,5 @@ void DeleteWeightDequantLinearOpEncoderPass::ApplyImpl(ir::Graph* graph) const {
 }  // namespace framework
 }  // namespace paddle

-REGISTER_PASS(delete_weight_dequant_linear_op_encoder_pass,
-              paddle::framework::ir::DeleteWeightDequantLinearOpEncoderPass);
+REGISTER_PASS(trt_delete_weight_dequant_linear_op_pass,
+              paddle::framework::ir::TrtDeleteWeightQuantDequantLinearOpPass);
@@ -13,6 +13,7 @@
 // limitations under the License.

 #pragma once
+#include <vector>

 #include "paddle/fluid/framework/ir/fuse_pass_base.h"

@@ -20,10 +21,10 @@ namespace paddle {
 namespace framework {
 namespace ir {

-class DeleteWeightDequantLinearOpDecoderPass : public FusePassBase {
+class TrtDeleteWeightQuantDequantLinearOpPass : public FusePassBase {
  public:
-  DeleteWeightDequantLinearOpDecoderPass();
-  virtual ~DeleteWeightDequantLinearOpDecoderPass() {}
+  TrtDeleteWeightQuantDequantLinearOpPass();
+  virtual ~TrtDeleteWeightQuantDequantLinearOpPass() {}

  protected:
   void ApplyImpl(ir::Graph* graph) const override;
......
@@ -90,7 +90,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "delete_fill_constant_op_pass",              //
       "delete_quant_dequant_op_pass",              //
       "delete_quant_dequant_filter_op_pass",       //
-      "delete_weight_dequant_linear_op_pass",      //
+      "trt_delete_weight_dequant_linear_op_pass",  //
       "delete_quant_dequant_linear_op_pass",       //
       "identity_scale_op_clean_pass",              //
       "add_support_int8_pass",                     //
@@ -161,8 +161,7 @@ const std::vector<std::string> kLiteSubgraphPasses({
 const std::vector<std::string> kGpuLowerPrecisionPasses{
     "simplify_with_basic_ops_pass",
     "delete_quant_dequant_linear_op_pass",
-    "delete_weight_dequant_linear_op_encoder_pass",
-    "delete_weight_dequant_linear_op_decoder_pass",
+    "delete_weight_dequant_linear_op_pass",
     "map_depthwise_conv_to_conv_pass",
     "conv_bn_fuse_pass",
     "conv_eltwiseadd_bn_fuse_pass",
@@ -210,8 +209,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
       "is_test_pass",                                  //
       "simplify_with_basic_ops_pass",                  //
       "delete_quant_dequant_linear_op_pass",           //
-      "delete_weight_dequant_linear_op_encoder_pass",  //
-      "delete_weight_dequant_linear_op_decoder_pass",  //
+      "delete_weight_dequant_linear_op_pass",          //
       "map_depthwise_conv_to_conv_pass",               //
       "conv_bn_fuse_pass",                             //
       "conv_eltwiseadd_bn_fuse_pass",                  //
......