Unverified commit 1b58ce14, authored by: W Wangzheee, committed by: GitHub

[Paddle inference] support new quant_model (#41049)

* paddle inference support new quant_model
Parent 4a09da02
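
This commit moves Paddle-TRT int8 support onto the new quantize_linear / dequantize_linear export format: the added passes delete those ops from the graph, dequantize int8 weights back to float in place, and record activation scales as op attributes ("Input_scale_<var name>", later folded into "X"/"Out" attributes plus a "support_int8" flag). As a reference for the arithmetic only, below is a minimal, hypothetical standalone sketch of the per-layer weight dequantization performed by delete_weight_dequant_linear_op_pass; the helper name and signature are illustrative and not part of the commit.

// Hypothetical helper mirroring the quant_axis == -1 (per-layer) branch of
// DeleteWeightQuantDequantLinearOpPass::ApplyImpl; illustrative only.
#include <cstdint>
#include <vector>

std::vector<float> DequantizePerLayerWeight(const std::vector<int8_t>& q_weight,
                                            float stored_scale, int bit_length) {
  // The scale stored in the model is divided by the quantization range
  // (127 for bit_length == 8) before being applied to each element.
  const int range = (1 << (bit_length - 1)) - 1;
  const float scale = stored_scale / range;
  std::vector<float> w(q_weight.size());
  for (std::size_t i = 0; i < q_weight.size(); ++i) {
    w[i] = static_cast<float>(q_weight[i]) * scale;
  }
  return w;
}
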
......@@ -86,6 +86,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -19,11 +19,7 @@ namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(prev_op); \
GET_IR_NODE(prev_out); \
GET_IR_NODE(quant_op); \
GET_IR_NODE(quant_out);
#define GET_NODES GET_IR_NODE(quant_op);
void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "add_support_int8";
......@@ -37,10 +33,57 @@ void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
if (prev_op->Op()->HasAttr("out_threshold") &&
quant_op->Op()->HasAttr("out_threshold")) {
quant_op->Op()->SetAttr("support_int8", true);
bool inscale_flag = false;
bool outscale_flag = false;
auto* quanted_op_desc = quant_op->Op();
// If an input tensor of this op carries an "Input_scale_<var name>"
// attribute, copy that scale onto the attribute named after the input
// parameter. OpDesc attributes cannot hold std::vector<std::pair<>>.
// TODO: support a multi-tensor scale for one input.
for (size_t i = 0; i < quanted_op_desc->InputNames().size(); i++) {
if (quanted_op_desc->Input(quanted_op_desc->InputNames()[i]).size() > 0 &&
quanted_op_desc->HasAttr(
"Input_scale_" +
quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0])) {
inscale_flag = true;
quanted_op_desc->SetAttr(
quanted_op_desc->InputNames()[i],
quanted_op_desc->GetAttr(
"Input_scale_" +
quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0]));
}
}
// If an output tensor of this op feeds a consumer that carries an
// "Input_scale_<var name>" attribute, copy that scale onto the attribute
// named after the output parameter. OpDesc attributes cannot hold
// std::vector<std::pair<>>. TODO: support a multi-tensor scale for one output.
for (auto out_node : quant_op->outputs) {
for (auto out_op_node : out_node->outputs) {
for (auto name : out_op_node->Op()->InputNames()) {
for (auto input_name : out_op_node->Op()->Input(name)) {
if (out_op_node->Op()->HasAttr("Input_scale_" + input_name)) {
for (size_t i = 0; i < quanted_op_desc->OutputNames().size();
i++) {
if (quanted_op_desc->Output(quanted_op_desc->OutputNames()[i])
.size() > 0 &&
input_name ==
quanted_op_desc->Output(
quanted_op_desc->OutputNames()[i])[0]) {
outscale_flag = true;
quanted_op_desc->SetAttr(
quanted_op_desc->OutputNames()[i],
out_op_node->Op()->GetAttr("Input_scale_" + input_name));
}
}
}
}
}
}
}
quanted_op_desc->SetAttr("support_int8", inscale_flag && outscale_flag);
quanted_op_desc->Flush();
found_count++;
};
gpd(graph, handler);
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(quantize_linear_op_x); \
GET_IR_NODE(quantize_linear_op_scale); \
GET_IR_NODE(quantize_linear_op); \
GET_IR_NODE(quantize_linear_op_out); \
GET_IR_NODE(dequantize_linear_op); \
GET_IR_NODE(dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
}
// Delete the quantize_linear / dequantize_linear op pair, then record the input scale on the consumer op
void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_linear_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));
// Create pattern
patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_quant_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
std::unordered_set<const Node*> nodes2rm = {};
int bit_length =
BOOST_GET_CONST(int, quantize_linear_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
const LoDTensor& input_scale_tensor =
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / range;
auto* any_op2_desc = any_op2->Op();
any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
nodes2rm.insert(quantize_linear_op_scale);
nodes2rm.insert(quantize_linear_op);
nodes2rm.insert(quantize_linear_op_out);
nodes2rm.insert(dequantize_linear_op);
nodes2rm.insert(dequantize_linear_op_out);
// link x to any_op2
any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(quantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_quant_dequant_linear_op_pass,
paddle::framework::ir::DeleteQuantDequantLinearOpPass);
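
The "Input_scale_<var name>" attribute attached above is consumed later by add_support_int8_pass, which copies the scale onto an attribute named after the consuming op's input parameter (e.g. "X") so converters can look it up per argument. A condensed, self-contained toy model of that hand-off follows; AttrMap and PropagateInputScale are illustrative stand-ins for the real OpDesc API, assumed for this sketch only.

// Toy model of the hand-off between delete_quant_dequant_linear_op_pass
// (which writes "Input_scale_<var name>") and add_support_int8_pass
// (which copies the scale onto the input-parameter attribute, e.g. "X").
#include <map>
#include <string>

using AttrMap = std::map<std::string, float>;  // stand-in for OpDesc attributes

bool PropagateInputScale(AttrMap* attrs, const std::string& param_name,
                         const std::string& var_name) {
  auto it = attrs->find("Input_scale_" + var_name);
  if (it == attrs->end()) return false;  // no scale recorded for this tensor
  (*attrs)[param_name] = it->second;     // e.g. (*attrs)["X"] = scale
  return true;
}
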
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteQuantDequantLinearOpPass : public FusePassBase {
public:
DeleteQuantDequantLinearOpPass();
virtual ~DeleteQuantDequantLinearOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -61,7 +61,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
GET_NODES;
int bit_length =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
std::string input_scale_var_name =
......@@ -76,7 +75,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / range;
float input_scale = input_scale_data[0];
// Set input scale in attr, and relink nodes
std::string input_name = input->Var()->Name();
......@@ -85,12 +84,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
for (auto* quantized_node : outlinks) {
auto op_desc = quantized_node->Op();
std::string quantized_op_type = op_desc->Type();
if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2") {
op_desc->SetAttr("X_scale", input_scale);
} else {
op_desc->SetAttr("Input_scale", input_scale);
}
op_desc->SetAttr("Input_scale", input_scale);
op_desc->SetAttr("bit_length", bit_length);
op_desc->RenameInput(quant_dequant_output_name, input_name);
op_desc->Flush();
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete the dequantize_linear op and dequantize the weight tensor in place
void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_quantdequant_linear_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightQuantDequantLinearOpPass should not be null."));
// Create pattern
patterns::DeleteWeightQuantDequantLinearOpPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
std::unordered_set<const Node*> nodes2rm = {};
int bit_length = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
auto* any_op2_desc = any_op2->Op();
// get weight tensor
auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name())
->GetMutable<LoDTensor>();
int8_t* quantized_weight_data =
weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
auto w_dims = weight_tensor->dims();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<LoDTensor>();
float* weight_scale_data =
weight_scale_tensor->mutable_data<float>(platform::CPUPlace());
auto weight_scale_nums = weight_scale_tensor->numel();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i] / range);
}
// dequant weight
std::vector<float> weight_data_tmp;
weight_data_tmp.resize(weight_tensor->numel());
int quant_axis = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
                  platform::errors::InvalidArgument(
                      "When quant_axis == -1 (per-layer quant_dequant), "
                      "the number of weight_scale values should be 1."));
// float(weight) * scale
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data_tmp[i] =
static_cast<float>(quantized_weight_data[i]) * weight_scale[0];
}
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ(
    weight_scale_nums, w_dims[quant_axis],
    platform::errors::InvalidArgument(
        "When quant_axis == 0 (per-channel quant_dequant), the number of "
        "weight_scale values should equal the number of channels."));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
                  platform::errors::InvalidArgument(
                      "When quant_axis == 0 (per-channel quant_dequant), the "
                      "weight of conv2d, depthwise_conv2d or conv2d_fusion "
                      "should have 4 dims."));
for (int i = 0; i < weight_tensor->numel(); i++) {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
weight_scale[i / inner_size];
}
} else if (quant_axis == 1) {
PADDLE_ENFORCE_EQ(
    weight_scale_nums, w_dims[quant_axis],
    platform::errors::InvalidArgument(
        "When quant_axis == 1 (per-channel quant_dequant), the number of "
        "weight_scale values should equal the number of channels."));
if (w_dims.size() == 4) { // conv2d_transpose
std::string quantized_op_type = any_op2->Op()->Type();
PADDLE_ENFORCE_EQ(
    quantized_op_type, "conv2d_transpose",
    platform::errors::InvalidArgument(
        "When quant_axis == 1 (per-channel quant_dequant), only "
        "conv2d_transpose may have a 4-dim weight."));
for (int i = 0; i < weight_tensor->numel(); i++) {
int inner_size = w_dims[2] * w_dims[3];
weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
weight_scale[(i / inner_size) % w_dims[1]];
}
} else if (w_dims.size() == 2) {
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
weight_scale[i % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
    "When quant_axis == 1, weight dims should be 2 or 4; please check "
    "your model."));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
    "quant_axis should be -1, 0 or 1; please check the op attributes of "
    "your model."));
}
weight_tensor->clear(); // clear int weight
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
float* new_quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_quantized_weight_data, weight_data_tmp.data(),
weight_tensor->numel() * sizeof(float));
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_pass,
paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass);
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightQuantDequantLinearOpPass : public FusePassBase {
public:
DeleteWeightQuantDequantLinearOpPass();
virtual ~DeleteWeightQuantDequantLinearOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
......@@ -226,23 +226,34 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
// For anakin subgraph int8
// When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
// fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
// will add "input_scale", "weight_scale" which are extracted from
// will add "input_scale" which are extracted from
// fake_quant op and fake_dequant op to mul op, and then delete the
// fake_quant op and fake_dequant op in the graph. If the mul op has the
// scale info, we should add those to the fused fc.
auto* mul_op_desc = mul->Op();
auto* elementwise_add_op_desc = elementwise_add->Op();
if (mul_op_desc->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale"));
desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
if (mul_op_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
auto elementwise_desc = elementwise_add->Op();
if (elementwise_desc->HasAttr("out_scale"))
desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
}
auto* elementwise_add_op_desc = elementwise_add->Op();
if (mul_op_desc->HasAttr("Input_scale")) {
desc.SetAttr("Input_scale", mul_op_desc->GetAttr("Input_scale"));
}
bool inscale_flag = false;
bool outscale_flag = false;
if (mul_op_desc->HasAttr("X")) {
desc.SetAttr("X", mul_op_desc->GetAttr("X"));
inscale_flag = true;
}
if (elementwise_add_op_desc->HasAttr("Out")) {
desc.SetAttr("Out", elementwise_add_op_desc->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
// If we can find out_threshold in elementwise_add, then set it as the
// out_threshold of fc.
auto out_threshold_attr =
......
......@@ -298,8 +298,7 @@ void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
......@@ -372,9 +371,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale",
matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
......@@ -451,8 +448,7 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
}
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
......@@ -532,8 +528,7 @@ void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
......@@ -677,8 +672,7 @@ void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
......@@ -765,8 +759,7 @@ void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
......
......@@ -2949,6 +2949,84 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
any_op2->LinksFrom({quant_dequant_out});
}
void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "X")
->assert_is_persistable_var();
auto weight_dequantize_linear_op_scale =
pattern->NewNode(weight_dequantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "Scale")
->assert_is_persistable_var();
auto weight_dequantize_linear_op =
pattern->NewNode(weight_dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto weight_dequantize_linear_op_out =
pattern->NewNode(weight_dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
weight_dequantize_linear_op
->LinksFrom(
{weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
.LinksTo({weight_dequantize_linear_op_out});
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("quantize_linear", "X");
auto quantize_linear_op_scale =
pattern->NewNode(quantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("quantize_linear", "Scale")
->assert_is_persistable_var();
auto quantize_linear_op = pattern->NewNode(quantize_linear_op_repr())
->assert_is_op("quantize_linear");
auto quantize_linear_op_out =
pattern->NewNode(quantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("quantize_linear", "Y")
->assert_is_op_input("dequantize_linear", "X")
->assert_var_not_persistable();
// Cannot add this node yet. TODO(Wangzheee)
/*
auto dequantize_linear_op_scale =
pattern->NewNode(dequantize_linear_op_scale_repr())
->assert_is_op_input("dequantize_linear", "Scale")
->AsIntermediate();
*/
auto dequantize_linear_op = pattern->NewNode(dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto dequantize_linear_op_out =
pattern->NewNode(dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quantize_linear_op
->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale})
.LinksTo({quantize_linear_op_out});
dequantize_linear_op->LinksFrom({quantize_linear_op_out})
.LinksTo({dequantize_linear_op_out});
any_op2->LinksFrom({dequantize_linear_op_out});
}
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
const std::string &op_name, bool with_reshape_xshape,
bool with_transpose_xshape) {
......@@ -3311,25 +3389,14 @@ PDNode *patterns::LayerNorm::operator()() {
return shift_out;
}
// Add support int8 flag
// Add support int8 flag and out_threshold
PDNode *patterns::AddSupportInt8::operator()() {
auto prev_op =
pattern->NewNode(prev_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return node->Op()->HasAttr("out_threshold") ? true : false;
});
auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var();
auto quant_op =
pattern->NewNode(quant_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return node->Op()->HasAttr("out_threshold") ? true : false;
});
auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op();
auto quant_out =
pattern->NewNode(quant_out_repr())->assert_is_var()->AsOutput();
prev_op->LinksTo({prev_out});
prev_out->LinksTo({quant_op});
pattern->NewNode(quant_out_repr())
->assert_is_var()
->assert_more([&](Node *node) { return node->outputs.size() > 0; })
->AsOutput();
quant_op->LinksTo({quant_out});
return quant_out;
}
......
......@@ -1702,6 +1702,40 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
PATTERN_DECL_NODE(any_op2);
};
struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
PATTERN_DECL_NODE(weight_dequantize_linear_op);
PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantLinearOpPattern : public PatternBase {
DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(quantize_linear_op_x);
PATTERN_DECL_NODE(quantize_linear_op_scale);
PATTERN_DECL_NODE(quantize_linear_op);
PATTERN_DECL_NODE(quantize_linear_op_out);
PATTERN_DECL_NODE(dequantize_linear_op);
// PATTERN_DECL_NODE(dequantize_linear_op_scale); // Cannot add this node yet.
// TODO(Wangzheee)
PATTERN_DECL_NODE(dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
// Reshape + Transpose + Matmul
// named nodes:
// reshape_op, reshape_out, reshape_xshape,
......@@ -1887,8 +1921,6 @@ struct AddSupportInt8 : public PatternBase {
: PatternBase(pattern, name_scope, "Add_support_int8") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(prev_out);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
};
......
......@@ -862,43 +862,30 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
multihead_op_desc.SetAttr("head_number", head_number);
auto* mul0_op_desc = mul0->Op();
auto* mul1_op_desc = mul1->Op();
auto* mul2_op_desc = mul2->Op();
if (mul0_op_desc->HasAttr("enable_int8")) {
multihead_op_desc.SetAttr("enable_int8",
mul0_op_desc->GetAttr("enable_int8"));
// all mul op has same input.
// all mul op has same input.
if (multihead_op_desc.HasAttr("Input_scale")) {
multihead_op_desc.SetAttr("Input_scale",
mul0_op_desc->GetAttr("X_scale"));
auto weight_scale0 = BOOST_GET_CONST(
std::vector<float>, mul0_op_desc->GetAttr("weight_scale"));
auto weight_scale1 = BOOST_GET_CONST(
std::vector<float>, mul1_op_desc->GetAttr("weight_scale"));
auto weight_scale2 = BOOST_GET_CONST(
std::vector<float>, mul2_op_desc->GetAttr("weight_scale"));
auto weight_max = std::max(weight_scale0, weight_scale1);
weight_max = std::max(weight_max, weight_scale2);
multihead_op_desc.SetAttr("weight_scale", weight_max);
auto* add0_op_desc = eltadd0->Op();
auto* add1_op_desc = eltadd1->Op();
auto* add2_op_desc = eltadd2->Op();
if (add0_op_desc->HasAttr("out_threshold")) {
auto out_scale0 =
BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
auto out_scale1 =
BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
auto out_scale2 =
BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
auto out_scale_max = std::max(out_scale0, out_scale1);
out_scale_max = std::max(out_scale_max, out_scale2);
multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
}
mul0_op_desc->GetAttr("Input_scale"));
}
auto* add0_op_desc = eltadd0->Op();
auto* add1_op_desc = eltadd1->Op();
auto* add2_op_desc = eltadd2->Op();
if (add0_op_desc->HasAttr("out_threshold")) {
auto out_scale0 =
BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
auto out_scale1 =
BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
auto out_scale2 =
BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
auto out_scale_max = std::max(out_scale0, out_scale1);
out_scale_max = std::max(out_scale_max, out_scale2);
multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
}
auto* softmax_qk_op_desc = softmax_qk->Op();
auto* matmul_qk_op_desc = matmul_qk->Op();
if (matmul_qk_op_desc->HasAttr("X_scale")) {
if (matmul_qk_op_desc->HasAttr("Input_scale")) {
multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
if (softmax_qk_op_desc->HasAttr("out_threshold")) {
auto qkv_plugin_scale = BOOST_GET_CONST(
......
......@@ -341,7 +341,6 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node"));
Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node"));
int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
std::string input_scale_var_name = quant->Op()->Input("InScale").front();
......@@ -356,7 +355,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float in_scale = input_scale_data[0];
float scale_value = in_scale / range;
float scale_value = in_scale;
// Set input scale in attr, and relink nodes
std::string input_act_name = input_act->Var()->Name();
......@@ -369,11 +368,10 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
quantized_op_type == "conv2d_fusion" ||
quantized_op_type == "depthwise_conv2d" ||
quantized_op_type == "fc" ||
quantized_op_type == "conv2d_transpose") {
quantized_op_type == "conv2d_transpose" ||
quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2") {
op_desc->SetAttr("Input_scale", scale_value);
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "matmul_v2") {
op_desc->SetAttr("X_scale", scale_value);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported quantized op type %s.", quantized_op_type));
......@@ -619,7 +617,6 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
new_op_desc.SetInput("X", {new_input});
new_op_desc.SetOutput("Out", {new_output});
}
new_op_desc.SetAttr("weight_scale", weight_scale);
new_op_desc.Flush();
auto* new_op = graph->CreateOpNode(&new_op_desc);
IR_NODE_LINK_TO(quantized_op_input_node, new_op);
......
......@@ -297,11 +297,24 @@ void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y"));
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
bool inscale_flag = false;
bool outscale_flag = false;
if (matmul_op->Op()->HasAttr("X")) {
desc.SetAttr("X", matmul_op->Op()->GetAttr("X"));
inscale_flag = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
......@@ -370,12 +383,23 @@ void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y"));
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale",
matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
bool inscale_flag = false;
bool outscale_flag = false;
if (matmul_v2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X"));
inscale_flag = true;
}
if (matmul_v2_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, mul_node);
......@@ -448,11 +472,23 @@ void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
}
if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_v2_op->Op()->GetAttr("out_threshold"));
}
bool inscale_flag = false;
bool outscale_flag = false;
if (matmul_v2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X"));
inscale_flag = true;
}
if (matmul_v2_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
auto matmul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node);
......@@ -530,11 +566,24 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
bool inscale_flag_x = false;
bool outscale_flag = false;
if (squeeze2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", squeeze2_op->Op()->GetAttr("X"));
inscale_flag_x = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
......@@ -675,11 +724,24 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
bool inscale_flag_x = false;
bool outscale_flag = false;
if (reshape2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", reshape2_op->Op()->GetAttr("X"));
inscale_flag_x = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
if (!IsCompat(desc)) {
LOG(WARNING)
<< "TrtReshape2MatmulFusePass in out mul op compat failed.";
......@@ -763,11 +825,24 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
desc.SetAttr("out_threshold",
matmul_op->Op()->GetAttr("out_threshold"));
}
bool inscale_flag_x = false;
bool outscale_flag = false;
if (flatten2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", flatten2_op->Op()->GetAttr("X"));
inscale_flag_x = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(flatten2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
......
......@@ -76,10 +76,13 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
const std::vector<std::string> kTRTSubgraphPasses({
"adaptive_pool2d_convert_global_pass",
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", //
"shuffle_channel_detect_pass", //
"quant_conv2d_dequant_fuse_pass", //
"delete_quant_dequant_op_pass", //
"delete_quant_dequant_filter_op_pass", //
"delete_weight_dequant_linear_op_pass", //
"delete_quant_dequant_linear_op_pass", //
"add_support_int8_pass", //
// "fc_fuse_pass", //
"simplify_with_basic_ops_pass", //
"embedding_eltwise_layernorm_fuse_pass", //
......@@ -98,9 +101,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
"trt_map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", //
"add_support_int8_pass",
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
"tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
// guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
......
......@@ -68,12 +68,6 @@ class ActivationOpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
if (op_desc.HasAttr("out_scale")) {
#if IS_TRT_VERSION_GE(5130)
float out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_scale"));
engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
#endif
}
}
protected:
......
......@@ -49,11 +49,11 @@ class AffineChannelOpConverter : public OpConverter {
auto* scale_v = scope.FindVar(scale_name);
auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false);
float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t);
auto* bias_v = scope.FindVar(bias_name);
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t);
// tensorrt scalend layer only support spatial dims >= 2,
// so nhwc is not availabe (spatial dims == 0)
......
......@@ -49,18 +49,11 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
true, weight_scale);
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine->SetTensorDynamicRange(X, in_scale);
#endif
} else {
weight_data =
engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
}
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL,
platform::errors::InvalidArgument(
......@@ -115,7 +108,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front());
auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>();
bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(),
bias_tensor_data, false);
bias_tensor_data);
bias_size = static_cast<size_t>(bias_tensor_data->numel());
}
......
......@@ -48,17 +48,10 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
bool enable_int8 = op_desc.HasAttr("enable_int8");
if (enable_int8) {
float in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
true, weight_scale);
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine->SetTensorDynamicRange(X, in_scale);
} else {
weight_data =
engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
}
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL,
platform::errors::InvalidArgument(
......
......@@ -47,8 +47,7 @@ class DeformableConvOpConverter : public OpConverter {
auto* filter_var = scope.FindVar(filter_name);
auto* filter_tensor = filter_var->GetMutable<framework::LoDTensor>();
float* filter_data =
engine_->GetWeightCPUData(filter_name, filter_tensor, false);
float* filter_data = engine_->GetWeightCPUData(filter_name, filter_tensor);
const int c_o = filter_tensor->dims()[0];
const int c_i = filter_tensor->dims()[1];
......
......@@ -51,8 +51,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
float* weight_data = nullptr;
auto output_name = op_desc.Output("Out")[0];
weight_data =
engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t, false);
weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t);
nvinfer1::Dims dims_x = X->getDimensions();
auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) {
......@@ -112,13 +111,6 @@ class ElementwiseWeightOpConverter : public OpConverter {
RreplenishLayerAndOutput(layer, "elementwise_" + op_type_,
{output_name}, test_mode);
}
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
engine_->SetTensorDynamicRange(X, x_scale);
#endif
}
};
if (engine_->with_dynamic_shape()) {
......@@ -222,16 +214,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
auto common_func = [&](nvinfer1::ILayer* layer) {
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
CHECK(op_desc.HasAttr("Y_scale"));
float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
float y_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Y_scale"));
engine_->SetTensorDynamicRange(X, x_scale);
engine_->SetTensorDynamicRange(Y, y_scale);
#endif
}
};
if (dims_x.nbDims == dims_y.nbDims) {
......
......@@ -77,7 +77,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data;
};
......
......@@ -113,22 +113,20 @@ class FcOpConverter : public OpConverter {
// assigned from CPU memory, which can't be avoided.
float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
float in_scale = 0.;
if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr(i_name + "_scale"));
in_scale =
BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(),
Y_t, true, weight_scale);
bool support_int8 = false;
if (op_desc.HasAttr("support_int8")) {
support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8"));
}
float in_scale = 0;
if (enable_int8 || support_int8) {
if (enable_int8) {
in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
} else {
in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X"));
}
engine_->SetTensorDynamicRange(X, in_scale);
#endif
} else {
weight_data =
engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t, false);
}
weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t);
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL,
platform::errors::InvalidArgument(
......@@ -148,14 +146,18 @@ class FcOpConverter : public OpConverter {
auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output,
TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) {
if (enable_int8) {
if (enable_int8 || support_int8) {
// add conv layer
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
float out_scale = 0;
if (enable_int8) {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
} else {
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
}
nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
......@@ -235,8 +237,7 @@ class FcOpConverter : public OpConverter {
if (with_bias) {
auto* b_v = scope.GetVar(op_desc.Input("Bias").front());
auto* b_t = b_v->GetMutable<framework::LoDTensor>();
bias_data =
engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false);
bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t);
bias_num = b_t->numel();
}
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
......@@ -251,7 +252,7 @@ class FcOpConverter : public OpConverter {
// not add Shuffle layer in ernie's multihead.
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[3] == 1 && x_num_col_dims == 2) {
if (enable_int8) {
if (enable_int8 || support_int8) {
// add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 =
......@@ -265,8 +266,13 @@ class FcOpConverter : public OpConverter {
op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode"));
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
float out_scale = 0;
if (enable_int8) {
out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
} else {
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
}
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
......@@ -308,7 +314,7 @@ class FcOpConverter : public OpConverter {
auto* reshape_before_fc_layer =
reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
if (enable_int8) {
if (enable_int8 || support_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
regist_fc(reshape_itensor, n_output, weight, bias);
......
......@@ -48,7 +48,7 @@ class GroupNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data;
};
......
......@@ -49,8 +49,8 @@ class LeakyReluOpConverter : public OpConverter {
bool enable_int8 = op_desc.HasAttr("enable_int8");
if (enable_int8) {
CHECK(op_desc.HasAttr("X_scale"));
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
CHECK(op_desc.HasAttr("Input_scale"));
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input, in_scale);
}
#else
......
......@@ -64,7 +64,9 @@ class MatMulOpConverter : public OpConverter {
: nvinfer1::MatrixOperation::kNONE;
if (op_desc.HasAttr("support_int8") &&
engine_->precision() == AnalysisConfig::Precision::kInt8) {
BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")) &&
engine_->precision() == AnalysisConfig::Precision::kInt8 &&
platform::GetGPUComputeCapability(0) >= 75) {
if (engine_->with_dynamic_shape()) {
VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT "
"MatmulPluginLayer";
......
......@@ -40,22 +40,16 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8");
float in_scale = 0.;
if (enable_int8) {
in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data =
engine_->GetWeightCPUData(weight_name, weight_t, true, weight_scale);
if (op_desc.HasAttr("Input_scale")) {
in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input, in_scale);
} else {
weight_data = engine_->GetWeightCPUData(weight_name, weight_t, false);
}
weight_data = engine_->GetWeightCPUData(weight_name, weight_t);
float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false);
float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t);
std::vector<float> weight_data_tmp;
weight_data_tmp.resize(weight_t->numel());
memcpy(weight_data_tmp.data(), weight_data,
......@@ -85,6 +79,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
PADDLE_THROW(platform::errors::Fatal(
    "When use_oss is enabled, precision must be int8 or half, not float32."));
}
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())};
......@@ -93,7 +91,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast<int32_t>(bias_t->numel())};
if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved";
if (!enable_int8) {
if (!op_desc.HasAttr("Input_scale")) {
PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8."));
}
......@@ -213,7 +211,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1::ILayer* fc_layer = nullptr;
float dp_probs = 1.0 / 127.0;
if (enable_int8) {
if (op_desc.HasAttr("Input_scale")) {
nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n,
nv_ksize, weight, bias);
......@@ -222,7 +220,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
weight, bias);
}
if (enable_int8) {
if (op_desc.HasAttr("fc_out_threshold")) {
PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument(
"must have out threshold in multihead layers "
......@@ -241,14 +239,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1)
? nvinfer1::DataType::kHALF
: nvinfer1::DataType::kFLOAT);
if (enable_int8) {
type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
int type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8 &&
(engine_->precision() == AnalysisConfig::Precision::kInt8)) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
bool has_mask = true;
int var_seqlen = 1;
......@@ -335,7 +329,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
reshape_before_fc_dim.d[4] = 1;
auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (enable_int8) {
if (op_desc.HasAttr("Input_scale")) {
engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0),
in_scale);
}
......@@ -346,7 +340,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
// add layer fc
nvinfer1::ILayer* fc_layer = nullptr;
if (enable_int8) {
if (op_desc.HasAttr("Input_scale")) {
nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n,
......@@ -357,7 +351,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
n, weight.get(), bias.get());
}
if (enable_int8) {
if (op_desc.HasAttr("fc_out_threshold")) {
PADDLE_ENFORCE_EQ(
op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument(
......@@ -382,8 +376,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (enable_int8) {
with_fp16 = 1;
if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = true;
}
plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
......
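The plugin data type selection in the hunk above no longer keys off enable_int8: with use_oss the plugin starts from half precision and only drops to int8 when both qkv2context_plugin_int8 and the engine precision agree. A small standalone sketch of that decision (the local DataType enum stands in for nvinfer1::DataType; this is not the Paddle code):

#include <iostream>

// Values mirror nvinfer1::DataType, but this enum is a local stand-in.
enum class DataType { kFLOAT = 0, kHALF = 1, kINT8 = 2 };

DataType ChoosePluginType(bool qkv2context_plugin_int8, bool engine_is_int8) {
  // With use_oss the plugin always starts from half precision ...
  DataType type = DataType::kHALF;
  // ... and drops to int8 only when the plugin flag and the engine precision
  // agree, matching the condition in the converter above.
  if (qkv2context_plugin_int8 && engine_is_int8) type = DataType::kINT8;
  return type;
}

int main() {
  std::cout << static_cast<int>(ChoosePluginType(true, true)) << "\n";   // 2
  std::cout << static_cast<int>(ChoosePluginType(true, false)) << "\n";  // 1
  std::cout << static_cast<int>(ChoosePluginType(false, true)) << "\n";  // 1
  return 0;
}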
......@@ -145,42 +145,68 @@ class OpConverter {
(*it)(op, scope, test_mode);
size_t output_num = op_desc.OutputNames().size();
if (output_num == 1) { // The number of output is 1
if (op_desc.HasAttr("out_threshold")) {
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
std::string output_name = "";
if (op_desc.HasOutput("Output")) {
output_name = op_desc.Output("Output").front();
} else if (op_desc.HasOutput("Out")) {
output_name = op_desc.Output("Out").front();
} else if (op_desc.HasOutput("Y")) {
output_name = op_desc.Output("Y").front();
} else {
PADDLE_THROW(
platform::errors::NotFound("Op %s has out threshold but doesn't "
"have an output named \"Output\", "
"\"Out\" or \"Y\".",
op_desc.Type()));
}
      // single output: set its tensor dynamic range from out_threshold
if (op_desc.HasAttr("out_threshold")) {
float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
std::string output_name = "";
if (op_desc.HasOutput("Output")) {
output_name = op_desc.Output("Output").front();
} else if (op_desc.HasOutput("Out")) {
output_name = op_desc.Output("Out").front();
} else if (op_desc.HasOutput("Y")) {
output_name = op_desc.Output("Y").front();
} else {
PADDLE_THROW(
platform::errors::NotFound("Op %s has out threshold but doesn't "
"have an output named \"Output\", "
"\"Out\" or \"Y\".",
op_desc.Type()));
}
auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << ".";
}
      // indexed outputs: set tensor dynamic range from out_<i>_threshold
for (size_t i = 0; i < output_num; ++i) {
if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) {
float out_scale = BOOST_GET_CONST(
float, op_desc.GetAttr("out_" + std::to_string(i) + "_threshold"));
std::string output_name =
op_desc.Output(op_desc.OutputNames()[i]).front();
auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << ".";
}
} else if (output_num > 1) { // The number of outputs greater than 1
for (size_t i = 0; i < output_num; ++i) {
if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) {
float out_scale = BOOST_GET_CONST(
float,
op_desc.GetAttr("out_" + std::to_string(i) + "_threshold"));
std::string output_name =
op_desc.Output(op_desc.OutputNames()[i]).front();
auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << ".";
}
}
      // quantize_linear/dequantize_linear support for Paddle-TRT: per-slot scale attributes
std::vector<std::string> inputs_name = op_desc.InputNames();
std::vector<std::string> outputs_name = op_desc.OutputNames();
for (size_t i = 0; i < inputs_name.size(); i++) {
if (op_desc.HasAttr(inputs_name[i])) {
std::string input_tensor_name = op_desc.Input(inputs_name[i])[0];
auto* input_itensor = engine->GetITensor(input_tensor_name);
float input_scale =
BOOST_GET_CONST(float, op_desc.GetAttr(inputs_name[i]));
engine->SetTensorDynamicRange(input_itensor, input_scale);
VLOG(1) << "Set input tensor scale = " << input_scale
<< " for tensor: " << input_tensor_name << ".";
}
}
for (size_t i = 0; i < outputs_name.size(); i++) {
if (op_desc.HasAttr(outputs_name[i])) {
std::string output_tensor_name = op_desc.Output(outputs_name[i])[0];
auto* output_itensor = engine->GetITensor(output_tensor_name);
float output_scale =
BOOST_GET_CONST(float, op_desc.GetAttr(outputs_name[i]));
engine->SetTensorDynamicRange(output_itensor, output_scale);
VLOG(1) << "Set output tensor scale = " << output_scale
<< " for tensor: " << output_tensor_name << ".";
}
}
}
......
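The op_converter hunk above reads dynamic-range information from three kinds of attributes: out_threshold for a single output, out_<i>_threshold for indexed outputs, and attributes named after the input/output slots that the new quant/dequant-linear passes attach. A standalone sketch of that lookup convention (the attribute map and slot names here are hypothetical stand-ins, not the Paddle OpDesc API):

#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Attributes a quantized op might carry after the delete_*_linear_op passes.
  std::map<std::string, float> attrs = {
      {"out_threshold", 4.0f},    // range for the single-output case
      {"out_1_threshold", 2.0f},  // indexed range for the second output
      {"X", 0.25f},               // per-slot input scale, keyed by slot name
  };
  std::vector<std::string> input_slots = {"X", "Y"};
  std::vector<std::string> output_slots = {"Out", "Mask"};

  for (const auto& slot : input_slots) {
    auto it = attrs.find(slot);
    if (it != attrs.end())
      std::cout << "input slot " << slot << " -> scale " << it->second << "\n";
  }
  for (size_t i = 0; i < output_slots.size(); ++i) {
    auto it = attrs.find("out_" + std::to_string(i) + "_threshold");
    if (it != attrs.end())
      std::cout << "output " << i << " -> range " << it->second << "\n";
  }
  return 0;
}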
......@@ -132,11 +132,10 @@ class Pool2dOpConverter : public OpConverter {
}
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
CHECK(op_desc.HasAttr("Input_scale"));
float input_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input1, input_scale);
#endif
}
std::vector<int> real_paddings = paddings;
......
......@@ -123,8 +123,9 @@ class Pool3dOpConverter : public OpConverter {
nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) {
CHECK(op_desc.HasAttr("X_scale"));
float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
CHECK(op_desc.HasAttr("Input_scale"));
float input_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input1, input_scale);
}
......
......@@ -70,7 +70,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data;
};
......
......@@ -48,7 +48,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data;
};
......
......@@ -57,8 +57,8 @@ class PReluOpConverter : public OpConverter {
layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else {
#if IS_TRT_VERSION_GE(7000)
float* alpha_weight_data = engine_->GetWeightCPUData(
op_desc.Input("Alpha")[0], alpha_tensor, false);
float* alpha_weight_data =
engine_->GetWeightCPUData(op_desc.Input("Alpha")[0], alpha_tensor);
TensorRTEngine::Weight alpha_weight{
nvinfer1::DataType::kFLOAT, static_cast<void*>(alpha_weight_data),
static_cast<size_t>(alpha_tensor->numel())};
......
......@@ -40,7 +40,7 @@ class SkipLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false);
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data;
};
......
......@@ -356,9 +356,7 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
}
float *TensorRTEngine::GetWeightCPUData(const std::string &name,
framework::Tensor *weight_tensor,
bool enable_int8,
const std::vector<float> &scale) {
framework::Tensor *weight_tensor) {
static int name_suffix_counter = 0;
std::string name_suffix = std::to_string(name_suffix_counter);
std::string splitter = "__";
......
......@@ -389,8 +389,7 @@ class TensorRTEngine {
}
float* GetWeightCPUData(const std::string& name,
framework::Tensor* weight_tensor, bool enable_int8,
const std::vector<float>& scale = {});
framework::Tensor* weight_tensor);
// A pointer to CPU memory is needed for the TRT weight.
// Before TRT runs, fluid loads the weight into GPU storage.
......
type: "dequantize_linear"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "ZeroPoint"
}
outputs {
name: "Y"
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "quant_axis"
type: INT
}
}
extra {
}
......@@ -60,15 +60,7 @@ extra {
type: BOOLEAN
}
attrs {
name: "X_scale"
type: FLOAT
}
attrs {
name: "weight_scale"
type: FLOAT
}
attrs {
name: "out_scale"
name: "Input_scale"
type: FLOAT
}
attrs {
......
type: "quantize_linear"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "ZeroPoint"
}
outputs {
name: "Y"
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "quant_axis"
type: INT
}
}
extra {
}
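For reference, the two op definitions above describe ordinary linear quantization. A minimal per-tensor sketch of the math they imply (assuming the common symmetric, ONNX-style semantics; per-channel handling via quant_axis is omitted):

// Sketch of per-tensor quantize_linear / dequantize_linear math.
#include <algorithm>
#include <cmath>
#include <cstdio>

float QuantizeLinear(float x, float scale, int zero_point, int bit_length) {
  const int qmax = (1 << (bit_length - 1)) - 1;  // 127 when bit_length == 8
  const int qmin = -qmax;                        // symmetric range
  float q = std::round(x / scale) + zero_point;
  return std::min<float>(std::max<float>(q, qmin), qmax);
}

float DequantizeLinear(float q, float scale, int zero_point) {
  return (q - zero_point) * scale;
}

int main() {
  float scale = 0.05f;
  float q = QuantizeLinear(1.234f, scale, /*zero_point=*/0, /*bit_length=*/8);
  std::printf("q = %.0f, back = %.4f\n", q, DequantizeLinear(q, scale, 0));
  return 0;
}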
......@@ -491,8 +491,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"x_num_col_dims": 2,
"y_num_col_dims": 1,
"enable_int8": True,
"X_scale": 1.0,
"weight_scale": [1.0],
"Input_scale": 1.0,
}, {
"axis": 2,
"out_threshold": 1.0,
......@@ -504,8 +503,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"x_num_col_dims": 2,
"y_num_col_dims": 1,
"enable_int8": True,
"X_scale": 1.0,
"weight_scale": [1.0],
"Input_scale": 1.0,
}, {
"axis": 2,
"out_threshold": 1.0,
......@@ -517,8 +515,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"x_num_col_dims": 2,
"y_num_col_dims": 1,
"enable_int8": True,
"X_scale": 1.0,
"weight_scale": [1.0],
"Input_scale": 1.0,
}, {
"axis": 2,
"out_threshold": 1.0,
......