Unverified commit 1b58ce14, authored by W Wangzheee, committed by GitHub

[Paddle inference] support new quant_model (#41049)

* paddle inference support new quant_model
Parent 4a09da02
@@ -86,6 +86,8 @@ pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(delete_weight_dequant_linear_op_pass inference)
pass_library(delete_quant_dequant_linear_op_pass inference)
pass_library(delete_dropout_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
......
-// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,11 +19,7 @@ namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
-#define GET_NODES          \
-  GET_IR_NODE(prev_op);    \
-  GET_IR_NODE(prev_out);   \
-  GET_IR_NODE(quant_op);   \
-  GET_IR_NODE(quant_out);
#define GET_NODES GET_IR_NODE(quant_op);
void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
  const std::string pattern_name = "add_support_int8";
@@ -37,10 +33,57 @@ void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
-   if (prev_op->Op()->HasAttr("out_threshold") &&
-       quant_op->Op()->HasAttr("out_threshold")) {
-     quant_op->Op()->SetAttr("support_int8", true);
-   }
    bool inscale_flag = false;
    bool outscale_flag = false;
auto* quanted_op_desc = quant_op->Op();
// If an input tensor carries an input scale (recorded by an earlier pass as
// "Input_scale_" + var name), copy that scale onto this op under the input's
// parameter name.
// OpDesc attributes cannot hold std::vector<std::pair<>>. TODO: support a
// separate scale for each tensor of one input.
for (size_t i = 0; i < quanted_op_desc->InputNames().size(); i++) {
if (quanted_op_desc->Input(quanted_op_desc->InputNames()[i]).size() > 0 &&
quanted_op_desc->HasAttr(
"Input_scale_" +
quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0])) {
inscale_flag = true;
quanted_op_desc->SetAttr(
quanted_op_desc->InputNames()[i],
quanted_op_desc->GetAttr(
"Input_scale_" +
quanted_op_desc->Input(quanted_op_desc->InputNames()[i])[0]));
}
}
// If an output tensor carries an output scale (recorded as the input scale of
// the op that consumes it), copy that scale onto this op under the output's
// parameter name.
// OpDesc attributes cannot hold std::vector<std::pair<>>. TODO: support a
// separate scale for each tensor of one output.
for (auto out_node : quant_op->outputs) {
for (auto out_op_node : out_node->outputs) {
for (auto name : out_op_node->Op()->InputNames()) {
for (auto input_name : out_op_node->Op()->Input(name)) {
if (out_op_node->Op()->HasAttr("Input_scale_" + input_name)) {
for (size_t i = 0; i < quanted_op_desc->OutputNames().size();
i++) {
if (quanted_op_desc->Output(quanted_op_desc->OutputNames()[i])
.size() > 0 &&
input_name ==
quanted_op_desc->Output(
quanted_op_desc->OutputNames()[i])[0]) {
outscale_flag = true;
quanted_op_desc->SetAttr(
quanted_op_desc->OutputNames()[i],
out_op_node->Op()->GetAttr("Input_scale_" + input_name));
}
}
}
}
}
}
}
quanted_op_desc->SetAttr("support_int8", inscale_flag && outscale_flag);
quanted_op_desc->Flush();
found_count++;
};
gpd(graph, handler);
......
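For orientation, here is a minimal sketch of how a TensorRT op converter could consume the attributes this pass writes. It is illustrative only and not part of the commit; op_desc, engine_, input and output are assumed converter-side variables (compare the converter hunks near the end of this diff).

  // Hypothetical converter-side use of the attributes written by
  // add_support_int8_pass (variable names assumed, not from the commit).
  if (op_desc.HasAttr("support_int8") &&
      BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8"))) {
    // The pass copied each scale onto the op under the corresponding
    // input/output parameter name, e.g. "X" and "Out".
    float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X"));
    float out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
    engine_->SetTensorDynamicRange(input, x_scale);
    engine_->SetTensorDynamicRange(output, out_scale);
  }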
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_quant_dequant_linear_op_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(quantize_linear_op_x); \
GET_IR_NODE(quantize_linear_op_scale); \
GET_IR_NODE(quantize_linear_op); \
GET_IR_NODE(quantize_linear_op_out); \
GET_IR_NODE(dequantize_linear_op); \
GET_IR_NODE(dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteQuantDequantLinearOpPass::DeleteQuantDequantLinearOpPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
}
// Delete the quantize_linear op and its paired dequantize_linear op, then
// record the input scale on the consuming op as "Input_scale_" + input var name.
void DeleteQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_linear_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument(
"Scope in DeleteQuantDequantLinearOpPass should not be null."));
// Create pattern
patterns::DeleteQuantDequantLinearOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_quant_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
std::unordered_set<const Node*> nodes2rm = {};
int bit_length =
BOOST_GET_CONST(int, quantize_linear_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
// Get input scale from tensor
const LoDTensor& input_scale_tensor =
scope->GetVar(quantize_linear_op_scale->Name())->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
paddle::platform::is_cpu_place(input_scale_tensor.place()), true,
platform::errors::InvalidArgument(
"Input scale tensor's place should be CPU."));
const float* input_scale_data = input_scale_tensor.data<float>();
float input_scale = input_scale_data[0] / range;
auto* any_op2_desc = any_op2->Op();
any_op2_desc->SetAttr("Input_scale_" + quantize_linear_op_x->Var()->Name(),
input_scale);
nodes2rm.insert(quantize_linear_op_scale);
nodes2rm.insert(quantize_linear_op);
nodes2rm.insert(quantize_linear_op_out);
nodes2rm.insert(dequantize_linear_op);
nodes2rm.insert(dequantize_linear_op_out);
// link x to any_op2
any_op2_desc->RenameInput(dequantize_linear_op_out->Var()->Name(),
quantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(quantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_quant_dequant_linear_op_pass,
paddle::framework::ir::DeleteQuantDequantLinearOpPass);
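A quick numeric sketch of the scale arithmetic in the handler above, using illustrative values and the common bit_length of 8 (the values are made up, not from the commit):

  int bit_length = 8;
  int range = (1 << (bit_length - 1)) - 1;   // 127
  float stored_scale = 12.7f;                // value read from the Scale tensor
  float input_scale = stored_scale / range;  // 0.1f, saved as "Input_scale_" + x->Name()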
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteQuantDequantLinearOpPass : public FusePassBase {
public:
DeleteQuantDequantLinearOpPass();
virtual ~DeleteQuantDequantLinearOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -61,7 +61,6 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
  GET_NODES;
  int bit_length =
      BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
- int range = ((1 << (bit_length - 1)) - 1);
  // Get input scale from tensor
  std::string input_scale_var_name =
@@ -76,7 +75,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
      platform::errors::InvalidArgument(
          "Input scale tensor's place should be CPU."));
  const float* input_scale_data = input_scale_tensor.data<float>();
- float input_scale = input_scale_data[0] / range;
  float input_scale = input_scale_data[0];
  // Set input scale in attr, and relink nodes
  std::string input_name = input->Var()->Name();
@@ -85,12 +84,7 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
  for (auto* quantized_node : outlinks) {
    auto op_desc = quantized_node->Op();
    std::string quantized_op_type = op_desc->Type();
-   if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-       quantized_op_type == "matmul_v2") {
-     op_desc->SetAttr("X_scale", input_scale);
-   } else {
-     op_desc->SetAttr("Input_scale", input_scale);
-   }
    op_desc->SetAttr("Input_scale", input_scale);
    op_desc->SetAttr("bit_length", bit_length);
    op_desc->RenameInput(quant_dequant_output_name, input_name);
    op_desc->Flush();
......
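Why the division by range disappears here: in this same commit the TensorRT converters stop multiplying Input_scale by 127 (see the conv2d/conv3d converter hunks below), so the pass now stores the tensor's dynamic range directly. An illustrative before/after, assuming the InScale tensor holds 6.35f:

  float old_attr = 6.35f / 127;  // 0.05f; old converters then multiplied by 127
  float new_attr = 6.35f;        // passed to SetTensorDynamicRange unchanged
  // Both conventions give the tensor a dynamic range of 6.35f in the engine.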
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_weight_dequant_linear_op_pass.h"
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(weight_dequantize_linear_op_x); \
GET_IR_NODE(weight_dequantize_linear_op_scale); \
GET_IR_NODE(weight_dequantize_linear_op); \
GET_IR_NODE(weight_dequantize_linear_op_out); \
GET_IR_NODE(any_op2);
DeleteWeightQuantDequantLinearOpPass::DeleteWeightQuantDequantLinearOpPass() {
AddOpCompat(OpCompat("quantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
AddOpCompat(OpCompat("dequantize_linear"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Scale")
.IsTensor()
.End()
.AddInput("ZeroPoint")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Y")
.IsTensor()
.End()
.AddAttr("bit_length")
.IsType<int>()
.End()
.AddAttr("quant_axis")
.IsType<int>()
.End();
AddOpCompat(OpCompat("conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("depthwise_conv2d"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddInput("ResidualData")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
AddOpCompat(OpCompat("mul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("x_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("y_num_col_dims")
.IsNumEQ(1)
.End();
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsBoolEQ(false)
.End()
.AddAttr("trans_y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("alpha")
.IsNumGE(0.99f)
.IsNumLE(1.01f)
.End()
.AddAttr("transpose_X")
.IsBoolEQ(false)
.End()
.AddAttr("transpose_Y")
.IsBoolEQ(false)
.End();
AddOpCompat(OpCompat("fc"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("W")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("in_num_col_dims")
.IsNumGE(1)
.End()
.AddAttr("activation_type")
.IsStringIn({"relu", ""})
.End();
AddOpCompat(OpCompat("conv2d_transpose"))
.AddInput("Input")
.IsTensor()
.End()
.AddInput("Filter")
.IsTensor()
.End()
.AddInput("Bias")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Output")
.IsTensor()
.End()
.AddAttr("output_padding")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("output_size")
.IsType<std::vector<int>>()
.IsOptional()
.End()
.AddAttr("groups")
.IsNumGE(1)
.End()
.AddAttr("dilations")
.IsType<std::vector<int>>()
.End()
.AddAttr("strides")
.IsType<std::vector<int>>()
.End()
.AddAttr("paddings")
.IsType<std::vector<int>>()
.End()
.AddAttr("padding_algorithm")
.IsOptional()
.IsStringIn({"EXPLICIT", "SAME", "VALID"})
.End()
.AddAttr("data_format")
.IsStringIn({"NCHW", "NHWC", "AnyLayout"})
.End();
}
// Delete the weight's dequantize_linear op, then dequantize the int8 weight
// back to float in place.
void DeleteWeightQuantDequantLinearOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name =
"delete_weight_quantdequant_linear_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
auto* scope = param_scope();
PADDLE_ENFORCE_NOT_NULL(
scope,
platform::errors::InvalidArgument(
"Scope in DeleteWeightQuantDequantLinearOpPass should not be null."));
// Create pattern
patterns::DeleteWeightQuantDequantLinearOpPattern pattern(
gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
/*
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "delete_weight_dequant_linear_op_pass "
"compat check failed.";
return;
}
*/
std::unordered_set<const Node*> nodes2rm = {};
int bit_length = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
auto* any_op2_desc = any_op2->Op();
// get weight tensor
auto* weight_tensor = scope->GetVar(weight_dequantize_linear_op_x->Name())
->GetMutable<LoDTensor>();
int8_t* quantized_weight_data =
weight_tensor->mutable_data<int8_t>(platform::CPUPlace());
auto w_dims = weight_tensor->dims();
// Get weight scale
std::vector<float> weight_scale;
auto* weight_scale_tensor =
scope->GetVar(weight_dequantize_linear_op_scale->Name())
->GetMutable<LoDTensor>();
float* weight_scale_data =
weight_scale_tensor->mutable_data<float>(platform::CPUPlace());
auto weight_scale_nums = weight_scale_tensor->numel();
for (int i = 0; i < weight_scale_nums; i++) {
weight_scale.push_back(weight_scale_data[i] / range);
}
// dequant weight
std::vector<float> weight_data_tmp;
weight_data_tmp.resize(weight_tensor->numel());  // resize, not reserve: elements are assigned by index below
int quant_axis = BOOST_GET_CONST(
int, weight_dequantize_linear_op->Op()->GetAttr("quant_axis"));
if (quant_axis == -1) { // per_layer quant_dequant: all OP
PADDLE_ENFORCE_EQ(weight_scale_nums, 1,
platform::errors::InvalidArgument(
"When quant_axis == -1 means use per_layer "
"quant_dequant, weight_scale'number should be 1."));
// float(weight) * scale
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data_tmp[i] =
static_cast<float>(quantized_weight_data[i]) * weight_scale[0];
}
} else if (quant_axis == 0) { // per_channel quant_dequant: conv2d,
// depthwise_conv2d, conv2d_fusion
PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));
PADDLE_ENFORCE_EQ(w_dims.size(), 4,
platform::errors::InvalidArgument(
"When quant_axis == 0 means use per_channel "
"quant_dequant, (conv2d, depthwise_conv2d, "
"conv2d_fusion)'s weight dims should be 4."));
for (int i = 0; i < weight_tensor->numel(); i++) {
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
weight_scale[i / inner_size];
}
} else if (quant_axis == 1) {
PADDLE_ENFORCE_EQ(
weight_scale_nums, w_dims[quant_axis],
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"weight_scale'numbers should be equal channels."));
if (w_dims.size() == 4) { // conv2d_transpose
std::string quantized_op_type = any_op2->Op()->Type();
PADDLE_ENFORCE_EQ(
quantized_op_type, "conv2d_transpose",
platform::errors::InvalidArgument(
"When quant_axis == 1 means use per_channel quant_dequant, "
"only conv2d_transpose weight dims equal 4."));
for (int i = 0; i < weight_tensor->numel(); i++) {
int inner_size = w_dims[2] * w_dims[3];
weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
weight_scale[(i / inner_size) % w_dims[1]];
}
} else if (w_dims.size() == 2) {
for (int i = 0; i < weight_tensor->numel(); i++) {
weight_data_tmp[i] = static_cast<float>(quantized_weight_data[i]) *
weight_scale[i % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When quant_axis == 1 , weight dims should be 2 or 4, please check "
"your model "));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"quant_axis should be -1 or 0 or 1, please check your model "
"OP'attribute "));
}
weight_tensor->clear(); // clear int weight
weight_tensor->Resize(phi::make_ddim(phi::vectorize(w_dims)));
float* new_quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
memcpy(new_quantized_weight_data, weight_data_tmp.data(),
weight_tensor->numel() * sizeof(float));
nodes2rm.insert(weight_dequantize_linear_op_scale);
nodes2rm.insert(weight_dequantize_linear_op);
nodes2rm.insert(weight_dequantize_linear_op_out);
// relink weight to any_op2
any_op2_desc->RenameInput(weight_dequantize_linear_op_out->Var()->Name(),
weight_dequantize_linear_op_x->Var()->Name());
any_op2_desc->Flush();
IR_NODE_LINK_TO(weight_dequantize_linear_op_x, any_op2);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_weight_dequant_linear_op_pass,
paddle::framework::ir::DeleteWeightQuantDequantLinearOpPass);
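A compact sketch of the per-channel indexing used above, with made-up shapes (not from the commit): for a conv2d filter with quant_axis == 0 and w_dims = [8, 4, 3, 3], weight_scale has 8 entries and

  int inner_size = 4 * 3 * 3;  // w_dims[1] * w_dims[2] * w_dims[3] = 36
  // Flattened element i belongs to output channel i / inner_size, so
  // i = 0..35 uses weight_scale[0], i = 36..71 uses weight_scale[1], and so on.
  // For quant_axis == 1 the channel index is (i / (w_dims[2] * w_dims[3])) % w_dims[1]
  // for 4-D conv2d_transpose weights, or i % w_dims[1] for 2-D weights.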
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class DeleteWeightQuantDequantLinearOpPass : public FusePassBase {
public:
DeleteWeightQuantDequantLinearOpPass();
virtual ~DeleteWeightQuantDequantLinearOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -226,23 +226,34 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
  // For anakin subgraph int8
  // When in anakin subgraph int8 mode, the pattern like "fake_quant + mul +
  // fake_dequant" can be detected by the quant_dequant_fuse_pass. This pass
- // will add "input_scale", "weight_scale" which are extracted from
  // will add "input_scale" which are extracted from
  // fake_quant op and fake_dequant op to mul op, and then delete the
  // fake_quant op and fake_dequant op in the graph. If the mul op has the
  // scale info, we should add those to the fused fc.
  auto* mul_op_desc = mul->Op();
  auto* elementwise_add_op_desc = elementwise_add->Op();
  if (mul_op_desc->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", mul_op_desc->GetAttr("enable_int8"));
-   desc.SetAttr("Input_scale", mul_op_desc->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", mul_op_desc->GetAttr("weight_scale"));
-   if (mul_op_desc->HasAttr("out_scale"))
-     desc.SetAttr("out_scale", mul_op_desc->GetAttr("out_scale"));
-   auto elementwise_desc = elementwise_add->Op();
-   if (elementwise_desc->HasAttr("out_scale"))
-     desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
  }
- auto* elementwise_add_op_desc = elementwise_add->Op();
  if (mul_op_desc->HasAttr("Input_scale")) {
    desc.SetAttr("Input_scale", mul_op_desc->GetAttr("Input_scale"));
  }
  bool inscale_flag = false;
  bool outscale_flag = false;
  if (mul_op_desc->HasAttr("X")) {
    desc.SetAttr("X", mul_op_desc->GetAttr("X"));
    inscale_flag = true;
  }
  if (elementwise_add_op_desc->HasAttr("Out")) {
    desc.SetAttr("Out", elementwise_add_op_desc->GetAttr("Out"));
    outscale_flag = true;
  }
  desc.SetAttr("support_int8", inscale_flag && outscale_flag);
  // if we can find out_threshold in elementwise_add, then set it as the
  // out_thrshold of fc
  auto out_threshold_attr =
......
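In the block above, the fused fc's input scale is taken from the mul op's "X" attribute and its output scale from the elementwise_add's "Out" attribute, since after fusion the fc's output is the elementwise_add's output. An illustrative view of the resulting OpDesc attributes (values are made up, not from the commit):

  // "Input_scale"  : 0.017f   copied from mul, if present
  // "X"            : 0.017f   per-slot input scale, copied from mul
  // "Out"          : 0.042f   per-slot output scale, copied from elementwise_add
  // "support_int8" : true     set because both scales were found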
@@ -298,8 +298,7 @@ void GpuCpuMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
@@ -372,9 +371,7 @@ void GpuCpuMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale",
-                matmul_v2_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_v2_op->Op()->GetAttr("out_threshold"));
  }
@@ -451,8 +448,7 @@ void GpuCpuMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
  }
  if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_v2_op->Op()->GetAttr("out_threshold"));
  }
@@ -532,8 +528,7 @@ void GpuCpuSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
@@ -677,8 +672,7 @@ void GpuCpuReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
@@ -765,8 +759,7 @@ void GpuCpuFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
......
@@ -2949,6 +2949,84 @@ void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
  any_op2->LinksFrom({quant_dequant_out});
}
void patterns::DeleteWeightQuantDequantLinearOpPattern::operator()() {
auto weight_dequantize_linear_op_x =
pattern->NewNode(weight_dequantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "X")
->assert_is_persistable_var();
auto weight_dequantize_linear_op_scale =
pattern->NewNode(weight_dequantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("dequantize_linear", "Scale")
->assert_is_persistable_var();
auto weight_dequantize_linear_op =
pattern->NewNode(weight_dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto weight_dequantize_linear_op_out =
pattern->NewNode(weight_dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
weight_dequantize_linear_op
->LinksFrom(
{weight_dequantize_linear_op_x, weight_dequantize_linear_op_scale})
.LinksTo({weight_dequantize_linear_op_out});
any_op2->LinksFrom({weight_dequantize_linear_op_out});
}
void patterns::DeleteQuantDequantLinearOpPattern::operator()() {
auto quantize_linear_op_x = pattern->NewNode(quantize_linear_op_x_repr())
->AsInput()
->assert_is_op_input("quantize_linear", "X");
auto quantize_linear_op_scale =
pattern->NewNode(quantize_linear_op_scale_repr())
->AsInput()
->assert_is_op_input("quantize_linear", "Scale")
->assert_is_persistable_var();
auto quantize_linear_op = pattern->NewNode(quantize_linear_op_repr())
->assert_is_op("quantize_linear");
auto quantize_linear_op_out =
pattern->NewNode(quantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("quantize_linear", "Y")
->assert_is_op_input("dequantize_linear", "X")
->assert_var_not_persistable();
// Can not add this node. Todo: Wangzheee
/*
auto dequantize_linear_op_scale =
pattern->NewNode(dequantize_linear_op_scale_repr())
->assert_is_op_input("dequantize_linear", "Scale")
->AsIntermediate();
*/
auto dequantize_linear_op = pattern->NewNode(dequantize_linear_op_repr())
->assert_is_op("dequantize_linear");
auto dequantize_linear_op_out =
pattern->NewNode(dequantize_linear_op_out_repr())
->AsIntermediate()
->assert_is_op_output("dequantize_linear", "Y");
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quantize_linear_op
->LinksFrom({quantize_linear_op_x, quantize_linear_op_scale})
.LinksTo({quantize_linear_op_out});
dequantize_linear_op->LinksFrom({quantize_linear_op_out})
.LinksTo({dequantize_linear_op_out});
any_op2->LinksFrom({dequantize_linear_op_out});
}
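For orientation, the subgraph matched by DeleteQuantDequantLinearOpPattern above can be sketched as follows (node names follow the pattern declarations; the drawing itself is not in the commit):

  quantize_linear_op_x  ----+
                            +--> quantize_linear --> quantize_linear_op_out
  quantize_linear_op_scale -+                                 |
                                                              v
                            dequantize_linear --> dequantize_linear_op_out --> any_op2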
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
    const std::string &op_name, bool with_reshape_xshape,
    bool with_transpose_xshape) {
@@ -3311,25 +3389,14 @@ PDNode *patterns::LayerNorm::operator()() {
  return shift_out;
}

-// Add support int8 flag
// Add support int8 flag and out_threshold
PDNode *patterns::AddSupportInt8::operator()() {
-  auto prev_op =
-      pattern->NewNode(prev_op_repr())
-          ->assert_is_op()
-          ->assert_more([&](Node *node) {
-            return node->Op()->HasAttr("out_threshold") ? true : false;
-          });
-  auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var();
-  auto quant_op =
-      pattern->NewNode(quant_op_repr())
-          ->assert_is_op()
-          ->assert_more([&](Node *node) {
-            return node->Op()->HasAttr("out_threshold") ? true : false;
-          });
-  auto quant_out =
-      pattern->NewNode(quant_out_repr())->assert_is_var()->AsOutput();
-  prev_op->LinksTo({prev_out});
-  prev_out->LinksTo({quant_op});
  auto quant_op = pattern->NewNode(quant_op_repr())->assert_is_op();
  auto quant_out =
      pattern->NewNode(quant_out_repr())
          ->assert_is_var()
          ->assert_more([&](Node *node) { return node->outputs.size() > 0; })
          ->AsOutput();
  quant_op->LinksTo({quant_out});
  return quant_out;
}
......
@@ -1702,6 +1702,40 @@ struct DeleteQuantDequantFilterOpPattern : public PatternBase {
  PATTERN_DECL_NODE(any_op2);
};
struct DeleteWeightQuantDequantLinearOpPattern : public PatternBase {
DeleteWeightQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_weight_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(weight_dequantize_linear_op_x);
PATTERN_DECL_NODE(weight_dequantize_linear_op_scale);
PATTERN_DECL_NODE(weight_dequantize_linear_op);
PATTERN_DECL_NODE(weight_dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantLinearOpPattern : public PatternBase {
DeleteQuantDequantLinearOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_quant_dequant_linear_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(quantize_linear_op_x);
PATTERN_DECL_NODE(quantize_linear_op_scale);
PATTERN_DECL_NODE(quantize_linear_op);
PATTERN_DECL_NODE(quantize_linear_op_out);
PATTERN_DECL_NODE(dequantize_linear_op);
// PATTERN_DECL_NODE(dequantize_linear_op_scale); // Can not add this node.
// Todo: Wangzheee
PATTERN_DECL_NODE(dequantize_linear_op_out);
PATTERN_DECL_NODE(any_op2);
};
// Reshape + Transpose + Matmul
// named nodes:
// reshape_op, reshape_out, reshape_xshape,
@@ -1887,8 +1921,6 @@ struct AddSupportInt8 : public PatternBase {
      : PatternBase(pattern, name_scope, "Add_support_int8") {}
  PDNode* operator()();
- PATTERN_DECL_NODE(prev_op);
- PATTERN_DECL_NODE(prev_out);
  PATTERN_DECL_NODE(quant_op);
  PATTERN_DECL_NODE(quant_out);
};
......
@@ -862,43 +862,30 @@ int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph,
  multihead_op_desc.SetAttr("head_number", head_number);
  auto* mul0_op_desc = mul0->Op();
- auto* mul1_op_desc = mul1->Op();
- auto* mul2_op_desc = mul2->Op();
- if (mul0_op_desc->HasAttr("enable_int8")) {
-   multihead_op_desc.SetAttr("enable_int8",
-                             mul0_op_desc->GetAttr("enable_int8"));
-   // all mul op has same input.
-   multihead_op_desc.SetAttr("Input_scale",
-                             mul0_op_desc->GetAttr("X_scale"));
-   auto weight_scale0 = BOOST_GET_CONST(
-       std::vector<float>, mul0_op_desc->GetAttr("weight_scale"));
-   auto weight_scale1 = BOOST_GET_CONST(
-       std::vector<float>, mul1_op_desc->GetAttr("weight_scale"));
-   auto weight_scale2 = BOOST_GET_CONST(
-       std::vector<float>, mul2_op_desc->GetAttr("weight_scale"));
-   auto weight_max = std::max(weight_scale0, weight_scale1);
-   weight_max = std::max(weight_max, weight_scale2);
-   multihead_op_desc.SetAttr("weight_scale", weight_max);
-   auto* add0_op_desc = eltadd0->Op();
-   auto* add1_op_desc = eltadd1->Op();
-   auto* add2_op_desc = eltadd2->Op();
-   if (add0_op_desc->HasAttr("out_threshold")) {
-     auto out_scale0 =
-         BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
-     auto out_scale1 =
-         BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
-     auto out_scale2 =
-         BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
-     auto out_scale_max = std::max(out_scale0, out_scale1);
-     out_scale_max = std::max(out_scale_max, out_scale2);
-     multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
-   }
- }
  // all mul op has same input.
  if (multihead_op_desc.HasAttr("Input_scale")) {
    multihead_op_desc.SetAttr("Input_scale",
                              mul0_op_desc->GetAttr("Input_scale"));
  }
  auto* add0_op_desc = eltadd0->Op();
  auto* add1_op_desc = eltadd1->Op();
  auto* add2_op_desc = eltadd2->Op();
  if (add0_op_desc->HasAttr("out_threshold")) {
    auto out_scale0 =
        BOOST_GET_CONST(float, add0_op_desc->GetAttr("out_threshold"));
    auto out_scale1 =
        BOOST_GET_CONST(float, add1_op_desc->GetAttr("out_threshold"));
    auto out_scale2 =
        BOOST_GET_CONST(float, add2_op_desc->GetAttr("out_threshold"));
    auto out_scale_max = std::max(out_scale0, out_scale1);
    out_scale_max = std::max(out_scale_max, out_scale2);
    multihead_op_desc.SetAttr("fc_out_threshold", out_scale_max);
  }
  auto* softmax_qk_op_desc = softmax_qk->Op();
  auto* matmul_qk_op_desc = matmul_qk->Op();
- if (matmul_qk_op_desc->HasAttr("X_scale")) {
  if (matmul_qk_op_desc->HasAttr("Input_scale")) {
    multihead_op_desc.SetAttr("qkv2context_plugin_int8", true);
    if (softmax_qk_op_desc->HasAttr("out_threshold")) {
      auto qkv_plugin_scale = BOOST_GET_CONST(
......
@@ -341,7 +341,6 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
  Node* output_scale = subgraph.at(pattern.GetPDNode("output_scale_node"));
  Node* output_act = subgraph.at(pattern.GetPDNode("output_act_node"));
  int bit_length = BOOST_GET_CONST(int, quant->Op()->GetAttr("bit_length"));
- int range = ((1 << (bit_length - 1)) - 1);
  // Get input scale from tensor
  std::string input_scale_var_name = quant->Op()->Input("InScale").front();
@@ -356,7 +355,7 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
          "Input scale tensor's place should be CPU."));
  const float* input_scale_data = input_scale_tensor.data<float>();
  float in_scale = input_scale_data[0];
- float scale_value = in_scale / range;
  float scale_value = in_scale;
  // Set input scale in attr, and relink nodes
  std::string input_act_name = input_act->Var()->Name();
@@ -369,11 +368,10 @@ void QuantDequantFusePass::DeleteQuant(ir::Graph* graph, Scope* scope,
      quantized_op_type == "conv2d_fusion" ||
      quantized_op_type == "depthwise_conv2d" ||
      quantized_op_type == "fc" ||
-     quantized_op_type == "conv2d_transpose") {
      quantized_op_type == "conv2d_transpose" ||
      quantized_op_type == "mul" || quantized_op_type == "matmul" ||
      quantized_op_type == "matmul_v2") {
    op_desc->SetAttr("Input_scale", scale_value);
- } else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
-            quantized_op_type == "matmul_v2") {
-   op_desc->SetAttr("X_scale", scale_value);
  } else {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Unsupported quantized op type %s.", quantized_op_type));
@@ -619,7 +617,6 @@ void QuantDequantFusePass::FuseDequant(ir::Graph* graph, Scope* scope,
    new_op_desc.SetInput("X", {new_input});
    new_op_desc.SetOutput("Out", {new_output});
  }
- new_op_desc.SetAttr("weight_scale", weight_scale);
  new_op_desc.Flush();
  auto* new_op = graph->CreateOpNode(&new_op_desc);
  IR_NODE_LINK_TO(quantized_op_input_node, new_op);
......
@@ -297,11 +297,24 @@ void TrtMapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("transpose_Y", matmul_op->Op()->GetAttr("transpose_Y"));
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
bool inscale_flag = false;
bool outscale_flag = false;
if (matmul_op->Op()->HasAttr("X")) {
desc.SetAttr("X", matmul_op->Op()->GetAttr("X"));
inscale_flag = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
@@ -370,12 +383,23 @@ void TrtMapMatmulV2ToMulPass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("transpose_Y", matmul_v2_op->Op()->GetAttr("trans_y"));
  if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale",
-                matmul_v2_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_v2_op->Op()->GetAttr("out_threshold"));
  }
bool inscale_flag = false;
bool outscale_flag = false;
if (matmul_v2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X"));
inscale_flag = true;
}
if (matmul_v2_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, mul_node);
@@ -448,11 +472,23 @@ void TrtMapMatmulV2ToMatmulPass::ApplyImpl(ir::Graph* graph) const {
  }
  if (matmul_v2_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_v2_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_v2_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_v2_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_v2_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_v2_op->Op()->GetAttr("out_threshold"));
  }
bool inscale_flag = false;
bool outscale_flag = false;
if (matmul_v2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", matmul_v2_op->Op()->GetAttr("X"));
inscale_flag = true;
}
if (matmul_v2_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_v2_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag && outscale_flag);
auto matmul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(matmul_v2_in_x, matmul_node);
IR_NODE_LINK_TO(matmul_v2_in_y, matmul_node);
@@ -530,11 +566,24 @@ void TrtSqueeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
bool inscale_flag_x = false;
bool outscale_flag = false;
if (squeeze2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", squeeze2_op->Op()->GetAttr("X"));
inscale_flag_x = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
@@ -675,11 +724,24 @@ void TrtReshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
bool inscale_flag_x = false;
bool outscale_flag = false;
if (reshape2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", reshape2_op->Op()->GetAttr("X"));
inscale_flag_x = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
if (!IsCompat(desc)) {
  LOG(WARNING)
      << "TrtReshape2MatmulFusePass in out mul op compat failed.";
@@ -763,11 +825,24 @@ void TrtFlatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  desc.SetAttr("y_num_col_dims", 1);
  if (matmul_op->Op()->HasAttr("enable_int8")) {
    desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
-   desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
-   desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
    desc.SetAttr("Input_scale", matmul_op->Op()->GetAttr("Input_scale"));
    desc.SetAttr("out_threshold",
                 matmul_op->Op()->GetAttr("out_threshold"));
  }
bool inscale_flag_x = false;
bool outscale_flag = false;
if (flatten2_op->Op()->HasAttr("X")) {
desc.SetAttr("X", flatten2_op->Op()->GetAttr("X"));
inscale_flag_x = true;
}
if (matmul_op->Op()->HasAttr("Out")) {
desc.SetAttr("Out", matmul_op->Op()->GetAttr("Out"));
outscale_flag = true;
}
desc.SetAttr("support_int8", inscale_flag_x && outscale_flag);
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(flatten2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
......
@@ -76,10 +76,13 @@ void PaddlePassBuilder::ClearPasses() { passes_.clear(); }
const std::vector<std::string> kTRTSubgraphPasses({
    "adaptive_pool2d_convert_global_pass",
    "shuffle_channel_detect_pass",           //
    "quant_conv2d_dequant_fuse_pass",        //
    "delete_quant_dequant_op_pass",          //
    "delete_quant_dequant_filter_op_pass",   //
    "delete_weight_dequant_linear_op_pass",  //
    "delete_quant_dequant_linear_op_pass",   //
    "add_support_int8_pass",                 //
    // "fc_fuse_pass",                       //
    "simplify_with_basic_ops_pass",          //
    "embedding_eltwise_layernorm_fuse_pass", //
@@ -98,9 +101,8 @@ const std::vector<std::string> kTRTSubgraphPasses({
    "trt_map_matmul_to_mul_pass",     //
    "fc_fuse_pass",                   //
    "conv_elementwise_add_fuse_pass", //
-   "add_support_int8_pass",
    "tensorrt_subgraph_pass", //
    "conv_bn_fuse_pass",      //
#if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                           // guaranteed at least v7
// cudnn8.0 has memory leak problem in conv + eltwise + act, so we
......
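Ordering note: delete_quant_dequant_linear_op_pass records the per-variable "Input_scale_" attributes that add_support_int8_pass later copies onto each op, so add_support_int8_pass now runs right after the deletion passes instead of just before tensorrt_subgraph_pass. A minimal sketch of how one might inspect the active pass list, assuming the public paddle_infer::Config API, the include name below, and a local ./model_dir (none of this is part of the commit):

  #include <iostream>
  #include "paddle_inference_api.h"  // assumed include name/path

  int main() {
    paddle_infer::Config config("./model_dir");
    config.EnableUseGpu(100, 0);
    config.EnableTensorRtEngine(1 << 30, 1, 3,
                                paddle_infer::PrecisionType::kInt8, false, true);
    for (const auto& pass : config.pass_builder()->AllPasses()) {
      std::cout << pass << std::endl;  // prints the passes that will run
    }
    return 0;
  }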
@@ -68,12 +68,6 @@ class ActivationOpConverter : public OpConverter {
  auto output_name = op_desc.Output("Out")[0];
  RreplenishLayerAndOutput(layer, op_type_, {output_name}, test_mode);
- if (op_desc.HasAttr("out_scale")) {
-#if IS_TRT_VERSION_GE(5130)
-   float out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_scale"));
-   engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
-#endif
- }
}

protected:
......
@@ -49,11 +49,11 @@ class AffineChannelOpConverter : public OpConverter {
  auto* scale_v = scope.FindVar(scale_name);
  auto* scale_t = scale_v->GetMutable<framework::LoDTensor>();
- float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t, false);
  float* scale_ptr = engine_->GetWeightCPUData(scale_name, scale_t);
  auto* bias_v = scope.FindVar(bias_name);
  auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
- float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t, false);
  float* bias_ptr = engine_->GetWeightCPUData(bias_name, bias_t);
  // tensorrt scalend layer only support spatial dims >= 2,
  // so nhwc is not availabe (spatial dims == 0)
......
@@ -49,18 +49,11 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
  if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
-   float in_scale =
-       BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
-   auto weight_scale =
-       BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
-   weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
-                                          true, weight_scale);
    float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
    engine->SetTensorDynamicRange(X, in_scale);
#endif
- } else {
-   weight_data =
-       engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
  }
  weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);
  PADDLE_ENFORCE_EQ(Y_t->dims().size(), 4UL,
                    platform::errors::InvalidArgument(
@@ -115,7 +108,7 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
    auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front());
    auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>();
    bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(),
-                                        bias_tensor_data, false);
                                         bias_tensor_data);
    bias_size = static_cast<size_t>(bias_tensor_data->numel());
  }
......
...@@ -48,17 +48,10 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op, ...@@ -48,17 +48,10 @@ void ConvertConv3d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
if (enable_int8) { if (enable_int8) {
float in_scale = float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127;
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t,
true, weight_scale);
engine->SetTensorDynamicRange(X, in_scale); engine->SetTensorDynamicRange(X, in_scale);
} else {
weight_data =
engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t, false);
} }
weight_data = engine->GetWeightCPUData(op_desc.Input("Filter").front(), Y_t);
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL, PADDLE_ENFORCE_EQ(Y_t->dims().size(), 5UL,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
......
...@@ -47,8 +47,7 @@ class DeformableConvOpConverter : public OpConverter { ...@@ -47,8 +47,7 @@ class DeformableConvOpConverter : public OpConverter {
auto* filter_var = scope.FindVar(filter_name); auto* filter_var = scope.FindVar(filter_name);
auto* filter_tensor = filter_var->GetMutable<framework::LoDTensor>(); auto* filter_tensor = filter_var->GetMutable<framework::LoDTensor>();
float* filter_data = float* filter_data = engine_->GetWeightCPUData(filter_name, filter_tensor);
engine_->GetWeightCPUData(filter_name, filter_tensor, false);
const int c_o = filter_tensor->dims()[0]; const int c_o = filter_tensor->dims()[0];
const int c_i = filter_tensor->dims()[1]; const int c_i = filter_tensor->dims()[1];
......
...@@ -51,8 +51,7 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -51,8 +51,7 @@ class ElementwiseWeightOpConverter : public OpConverter {
auto* Y_t = Y_v->GetMutable<framework::LoDTensor>(); auto* Y_t = Y_v->GetMutable<framework::LoDTensor>();
float* weight_data = nullptr; float* weight_data = nullptr;
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
weight_data = weight_data = engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t);
engine_->GetWeightCPUData(op_desc.Input("Y").front(), Y_t, false);
nvinfer1::Dims dims_x = X->getDimensions(); nvinfer1::Dims dims_x = X->getDimensions();
auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) { auto regist_eltwise_weight = [&](nvinfer1::ScaleMode scale_mode) {
...@@ -112,13 +111,6 @@ class ElementwiseWeightOpConverter : public OpConverter { ...@@ -112,13 +111,6 @@ class ElementwiseWeightOpConverter : public OpConverter {
RreplenishLayerAndOutput(layer, "elementwise_" + op_type_, RreplenishLayerAndOutput(layer, "elementwise_" + op_type_,
{output_name}, test_mode); {output_name}, test_mode);
} }
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
engine_->SetTensorDynamicRange(X, x_scale);
#endif
}
}; };
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
...@@ -222,16 +214,6 @@ class ElementwiseTensorOpConverter : public OpConverter { ...@@ -222,16 +214,6 @@ class ElementwiseTensorOpConverter : public OpConverter {
auto common_func = [&](nvinfer1::ILayer* layer) { auto common_func = [&](nvinfer1::ILayer* layer) {
RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "elementwise", {output_name}, test_mode);
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
CHECK(op_desc.HasAttr("Y_scale"));
float x_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale"));
float y_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Y_scale"));
engine_->SetTensorDynamicRange(X, x_scale);
engine_->SetTensorDynamicRange(Y, y_scale);
#endif
}
}; };
if (dims_x.nbDims == dims_y.nbDims) { if (dims_x.nbDims == dims_y.nbDims) {
......
...@@ -77,7 +77,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -77,7 +77,7 @@ class EmbEltwiseLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>(); auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims(); (*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data; return temp_data;
}; };
......
...@@ -113,22 +113,20 @@ class FcOpConverter : public OpConverter { ...@@ -113,22 +113,20 @@ class FcOpConverter : public OpConverter {
// assigned from CPU memory, which can't be avoided. // assigned from CPU memory, which can't be avoided.
float* weight_data = nullptr; float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
float in_scale = 0.; bool support_int8 = false;
if (enable_int8) { if (op_desc.HasAttr("support_int8")) {
#if IS_TRT_VERSION_GE(5000) support_int8 = BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8"));
CHECK(op_desc.HasAttr(i_name + "_scale")); }
in_scale = float in_scale = 0;
BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127; if (enable_int8 || support_int8) {
auto weight_scale = if (enable_int8) {
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale")); in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), } else {
Y_t, true, weight_scale); in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X"));
}
engine_->SetTensorDynamicRange(X, in_scale); engine_->SetTensorDynamicRange(X, in_scale);
#endif
} else {
weight_data =
engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t, false);
} }
weight_data = engine_->GetWeightCPUData(op_desc.Input(w_name).front(), Y_t);
PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL, PADDLE_ENFORCE_EQ(Y_t->dims().size(), 2UL,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -148,14 +146,18 @@ class FcOpConverter : public OpConverter { ...@@ -148,14 +146,18 @@ class FcOpConverter : public OpConverter {
auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output, auto regist_fc = [&](nvinfer1::ITensor* inputs, int n_output,
TensorRTEngine::Weight& weight, TensorRTEngine::Weight& weight,
TensorRTEngine::Weight& bias) { TensorRTEngine::Weight& bias) {
if (enable_int8) { if (enable_int8 || support_int8) {
// add conv layer // add conv layer
PADDLE_ENFORCE_EQ( float out_scale = 0;
op_desc.HasAttr("out_threshold"), true, if (enable_int8) {
platform::errors::InvalidArgument( PADDLE_ENFORCE_EQ(
"must have out threshold in fc layers in int8 mode")); op_desc.HasAttr("out_threshold"), true,
float out_scale = platform::errors::InvalidArgument(
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); "must have out threshold in fc layers in int8 mode"));
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
} else {
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
}
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 = auto* fc_layer_int8 =
TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output, TRT_ENGINE_ADD_LAYER(engine_, Convolution, *inputs, n_output,
...@@ -235,8 +237,7 @@ class FcOpConverter : public OpConverter { ...@@ -235,8 +237,7 @@ class FcOpConverter : public OpConverter {
if (with_bias) { if (with_bias) {
auto* b_v = scope.GetVar(op_desc.Input("Bias").front()); auto* b_v = scope.GetVar(op_desc.Input("Bias").front());
auto* b_t = b_v->GetMutable<framework::LoDTensor>(); auto* b_t = b_v->GetMutable<framework::LoDTensor>();
bias_data = bias_data = engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t);
engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false);
bias_num = b_t->numel(); bias_num = b_t->numel();
} }
TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
...@@ -251,7 +252,7 @@ class FcOpConverter : public OpConverter { ...@@ -251,7 +252,7 @@ class FcOpConverter : public OpConverter {
// not add Shuffle layer in ernie's multihead. // not add Shuffle layer in ernie's multihead.
if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 && if (engine_->use_oss() && engine_->with_ernie() && x_dim.nbDims == 4 &&
x_dim.d[3] == 1 && x_num_col_dims == 2) { x_dim.d[3] == 1 && x_num_col_dims == 2) {
if (enable_int8) { if (enable_int8 || support_int8) {
// add conv1x1 layer // add conv1x1 layer
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
auto* fc_layer_int8 = auto* fc_layer_int8 =
...@@ -265,8 +266,13 @@ class FcOpConverter : public OpConverter { ...@@ -265,8 +266,13 @@ class FcOpConverter : public OpConverter {
op_desc.HasAttr("out_threshold"), true, op_desc.HasAttr("out_threshold"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"must have out threshold in fc layers in int8 mode")); "must have out threshold in fc layers in int8 mode"));
float out_scale = float out_scale = 0;
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); if (enable_int8) {
out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
} else {
out_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Out"));
}
engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0), engine_->SetTensorDynamicRange(fc_layer_int8->getOutput(0),
out_scale); out_scale);
nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER( nvinfer1::IActivationLayer* relu_layer_int8 = TRT_ENGINE_ADD_LAYER(
...@@ -308,7 +314,7 @@ class FcOpConverter : public OpConverter { ...@@ -308,7 +314,7 @@ class FcOpConverter : public OpConverter {
auto* reshape_before_fc_layer = auto* reshape_before_fc_layer =
reshape_before_fc(X, x_dim, x_num_col_dims, output_name); reshape_before_fc(X, x_dim, x_num_col_dims, output_name);
auto* reshape_itensor = reshape_before_fc_layer->getOutput(0); auto* reshape_itensor = reshape_before_fc_layer->getOutput(0);
if (enable_int8) { if (enable_int8 || support_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale); engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
} }
regist_fc(reshape_itensor, n_output, weight, bias); regist_fc(reshape_itensor, n_output, weight, bias);
......
...@@ -48,7 +48,7 @@ class GroupNormOpConverter : public OpConverter { ...@@ -48,7 +48,7 @@ class GroupNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>(); auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims(); (*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data; return temp_data;
}; };
......
...@@ -49,8 +49,8 @@ class LeakyReluOpConverter : public OpConverter { ...@@ -49,8 +49,8 @@ class LeakyReluOpConverter : public OpConverter {
bool enable_int8 = op_desc.HasAttr("enable_int8"); bool enable_int8 = op_desc.HasAttr("enable_int8");
if (enable_int8) { if (enable_int8) {
CHECK(op_desc.HasAttr("X_scale")); CHECK(op_desc.HasAttr("Input_scale"));
float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); float in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input, in_scale); engine_->SetTensorDynamicRange(input, in_scale);
} }
#else #else
......
...@@ -64,7 +64,9 @@ class MatMulOpConverter : public OpConverter { ...@@ -64,7 +64,9 @@ class MatMulOpConverter : public OpConverter {
: nvinfer1::MatrixOperation::kNONE; : nvinfer1::MatrixOperation::kNONE;
if (op_desc.HasAttr("support_int8") && if (op_desc.HasAttr("support_int8") &&
engine_->precision() == AnalysisConfig::Precision::kInt8) { BOOST_GET_CONST(bool, op_desc.GetAttr("support_int8")) &&
engine_->precision() == AnalysisConfig::Precision::kInt8 &&
platform::GetGPUComputeCapability(0) >= 75) {
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT " VLOG(3) << "Convert a fluid matmul_op_int8_dynamic to TensorRT "
"MatmulPluginLayer"; "MatmulPluginLayer";
......
...@@ -40,22 +40,16 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -40,22 +40,16 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto* bias_t = bias_v->GetMutable<framework::LoDTensor>(); auto* bias_t = bias_v->GetMutable<framework::LoDTensor>();
float* weight_data = nullptr; float* weight_data = nullptr;
bool enable_int8 = op_desc.HasAttr("enable_int8");
bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8"); bool qkv2context_plugin_int8 = op_desc.HasAttr("qkv2context_plugin_int8");
float in_scale = 0.; float in_scale = 0.;
if (enable_int8) { if (op_desc.HasAttr("Input_scale")) {
in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale")) * 127; in_scale = BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
auto weight_scale =
BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
weight_data =
engine_->GetWeightCPUData(weight_name, weight_t, true, weight_scale);
engine_->SetTensorDynamicRange(input, in_scale); engine_->SetTensorDynamicRange(input, in_scale);
} else {
weight_data = engine_->GetWeightCPUData(weight_name, weight_t, false);
} }
weight_data = engine_->GetWeightCPUData(weight_name, weight_t);
float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t, false); float* bias_data = engine_->GetWeightCPUData(bias_name, bias_t);
std::vector<float> weight_data_tmp; std::vector<float> weight_data_tmp;
weight_data_tmp.reserve(weight_t->numel()); weight_data_tmp.reserve(weight_t->numel());
memcpy(weight_data_tmp.data(), weight_data, memcpy(weight_data_tmp.data(), weight_data,
...@@ -85,6 +79,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -85,6 +79,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
if (engine_->with_dynamic_shape()) { if (engine_->with_dynamic_shape()) {
if (engine_->use_oss()) { if (engine_->use_oss()) {
if (engine_->precision() == AnalysisConfig::Precision::kFloat32) {
PADDLE_THROW(platform::errors::Fatal(
"use use_oss must be int8 or half, not float32."));
}
nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT, nvinfer1::Weights weight{nvinfer1::DataType::kFLOAT,
static_cast<void*>(weight_data), static_cast<void*>(weight_data),
static_cast<int32_t>(weight_t->numel())}; static_cast<int32_t>(weight_t->numel())};
...@@ -93,7 +91,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -93,7 +91,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
static_cast<int32_t>(bias_t->numel())}; static_cast<int32_t>(bias_t->numel())};
if (engine_->with_interleaved()) { if (engine_->with_interleaved()) {
VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved"; VLOG(4) << "fused multihead_matmul op: use_oss and with_interleaved";
if (!enable_int8) { if (!op_desc.HasAttr("Input_scale")) {
PADDLE_THROW( PADDLE_THROW(
platform::errors::Fatal("use with_interleaved must be int8.")); platform::errors::Fatal("use with_interleaved must be int8."));
} }
...@@ -213,7 +211,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -213,7 +211,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
nvinfer1::ILayer* fc_layer = nullptr; nvinfer1::ILayer* fc_layer = nullptr;
float dp_probs = 1.0 / 127.0; float dp_probs = 1.0 / 127.0;
if (enable_int8) { if (op_desc.HasAttr("Input_scale")) {
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n, fc_layer = TRT_ENGINE_ADD_LAYER(engine_, Convolution, *input, n,
nv_ksize, weight, bias); nv_ksize, weight, bias);
...@@ -222,7 +220,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -222,7 +220,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
weight, bias); weight, bias);
} }
if (enable_int8) { if (op_desc.HasAttr("fc_out_threshold")) {
PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true, PADDLE_ENFORCE_EQ(op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"must have out threshold in multihead layers " "must have out threshold in multihead layers "
...@@ -241,14 +239,10 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -241,14 +239,10 @@ class MultiheadMatMulOpConverter : public OpConverter {
auto creator = GetPluginRegistry()->getPluginCreator( auto creator = GetPluginRegistry()->getPluginCreator(
"CustomQKVToContextPluginDynamic", "2"); "CustomQKVToContextPluginDynamic", "2");
assert(creator != nullptr); assert(creator != nullptr);
int type = static_cast<int>((engine_->WithFp16() == 1) int type = static_cast<int>(nvinfer1::DataType::kHALF);
? nvinfer1::DataType::kHALF if (qkv2context_plugin_int8 &&
: nvinfer1::DataType::kFLOAT); (engine_->precision() == AnalysisConfig::Precision::kInt8)) {
if (enable_int8) { type = static_cast<int>(nvinfer1::DataType::kINT8);
type = static_cast<int>(nvinfer1::DataType::kHALF);
if (qkv2context_plugin_int8) {
type = static_cast<int>(nvinfer1::DataType::kINT8);
}
} }
bool has_mask = true; bool has_mask = true;
int var_seqlen = 1; int var_seqlen = 1;
...@@ -335,7 +329,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -335,7 +329,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
reshape_before_fc_dim.d[4] = 1; reshape_before_fc_dim.d[4] = 1;
auto* reshape_before_fc_layer = auto* reshape_before_fc_layer =
TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input); TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *input);
if (enable_int8) { if (op_desc.HasAttr("Input_scale")) {
engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0), engine_->SetTensorDynamicRange(reshape_before_fc_layer->getOutput(0),
in_scale); in_scale);
} }
...@@ -346,7 +340,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -346,7 +340,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
// add layer fc // add layer fc
nvinfer1::ILayer* fc_layer = nullptr; nvinfer1::ILayer* fc_layer = nullptr;
if (enable_int8) { if (op_desc.HasAttr("Input_scale")) {
nvinfer1::DimsHW nv_ksize(1, 1); nvinfer1::DimsHW nv_ksize(1, 1);
fc_layer = TRT_ENGINE_ADD_LAYER( fc_layer = TRT_ENGINE_ADD_LAYER(
engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n, engine_, Convolution, *reshape_before_fc_layer->getOutput(0), n,
...@@ -357,7 +351,7 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -357,7 +351,7 @@ class MultiheadMatMulOpConverter : public OpConverter {
n, weight.get(), bias.get()); n, weight.get(), bias.get());
} }
if (enable_int8) { if (op_desc.HasAttr("fc_out_threshold")) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
op_desc.HasAttr("fc_out_threshold"), true, op_desc.HasAttr("fc_out_threshold"), true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -382,8 +376,8 @@ class MultiheadMatMulOpConverter : public OpConverter { ...@@ -382,8 +376,8 @@ class MultiheadMatMulOpConverter : public OpConverter {
bool with_fp16 = bool with_fp16 =
engine_->WithFp16() && !engine_->disable_trt_plugin_fp16(); engine_->WithFp16() && !engine_->disable_trt_plugin_fp16();
if (enable_int8) { if (engine_->precision() == AnalysisConfig::Precision::kInt8) {
with_fp16 = 1; with_fp16 = true;
} }
plugin::DynamicPluginTensorRT* plugin = plugin::DynamicPluginTensorRT* plugin =
new plugin::QkvToContextPluginDynamic(hidden_in, head_number, new plugin::QkvToContextPluginDynamic(hidden_in, head_number,
......
...@@ -145,42 +145,68 @@ class OpConverter { ...@@ -145,42 +145,68 @@ class OpConverter {
(*it)(op, scope, test_mode); (*it)(op, scope, test_mode);
size_t output_num = op_desc.OutputNames().size(); size_t output_num = op_desc.OutputNames().size();
if (output_num == 1) { // The number of output is 1 // only one out settensordynamicRange
if (op_desc.HasAttr("out_threshold")) { if (op_desc.HasAttr("out_threshold")) {
float out_scale = float out_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold")); BOOST_GET_CONST(float, op_desc.GetAttr("out_threshold"));
std::string output_name = ""; std::string output_name = "";
if (op_desc.HasOutput("Output")) { if (op_desc.HasOutput("Output")) {
output_name = op_desc.Output("Output").front(); output_name = op_desc.Output("Output").front();
} else if (op_desc.HasOutput("Out")) { } else if (op_desc.HasOutput("Out")) {
output_name = op_desc.Output("Out").front(); output_name = op_desc.Output("Out").front();
} else if (op_desc.HasOutput("Y")) { } else if (op_desc.HasOutput("Y")) {
output_name = op_desc.Output("Y").front(); output_name = op_desc.Output("Y").front();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
platform::errors::NotFound("Op %s has out threshold but doesn't " platform::errors::NotFound("Op %s has out threshold but doesn't "
"have an output named \"Output\", " "have an output named \"Output\", "
"\"Out\" or \"Y\".", "\"Out\" or \"Y\".",
op_desc.Type())); op_desc.Type()));
} }
auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << ".";
}
// outs settensordynamicRange
for (size_t i = 0; i < output_num; ++i) {
if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) {
float out_scale = BOOST_GET_CONST(
float, op_desc.GetAttr("out_" + std::to_string(i) + "_threshold"));
std::string output_name =
op_desc.Output(op_desc.OutputNames()[i]).front();
auto* output_itensor = engine->GetITensor(output_name); auto* output_itensor = engine->GetITensor(output_name);
engine->SetTensorDynamicRange(output_itensor, out_scale); engine->SetTensorDynamicRange(output_itensor, out_scale);
VLOG(1) << "Set out scale = " << out_scale << " for tensor " VLOG(1) << "Set out scale = " << out_scale << " for tensor "
<< output_name << "."; << output_name << ".";
} }
} else if (output_num > 1) { // The number of outputs greater than 1 }
for (size_t i = 0; i < output_num; ++i) {
if (op_desc.HasAttr("out_" + std::to_string(i) + "_threshold")) { // quant_dequant_linear support for paddle trt
float out_scale = BOOST_GET_CONST(
float, std::vector<std::string> inputs_name = op_desc.InputNames();
op_desc.GetAttr("out_" + std::to_string(i) + "_threshold")); std::vector<std::string> outputs_name = op_desc.OutputNames();
std::string output_name =
op_desc.Output(op_desc.OutputNames()[i]).front(); for (size_t i = 0; i < inputs_name.size(); i++) {
auto* output_itensor = engine->GetITensor(output_name); if (op_desc.HasAttr(inputs_name[i])) {
engine->SetTensorDynamicRange(output_itensor, out_scale); std::string input_tensor_name = op_desc.Input(inputs_name[i])[0];
VLOG(1) << "Set out scale = " << out_scale << " for tensor " auto* input_itensor = engine->GetITensor(input_tensor_name);
<< output_name << "."; float input_scale =
} BOOST_GET_CONST(float, op_desc.GetAttr(inputs_name[i]));
engine->SetTensorDynamicRange(input_itensor, input_scale);
VLOG(1) << "Set input tensor scale = " << input_scale
<< " for tensor: " << input_tensor_name << ".";
}
}
for (size_t i = 0; i < outputs_name.size(); i++) {
if (op_desc.HasAttr(outputs_name[i])) {
std::string output_tensor_name = op_desc.Output(outputs_name[i])[0];
auto* output_itensor = engine->GetITensor(output_tensor_name);
float output_scale =
BOOST_GET_CONST(float, op_desc.GetAttr(outputs_name[i]));
engine->SetTensorDynamicRange(output_itensor, output_scale);
VLOG(1) << "Set output tensor scale = " << output_scale
<< " for tensor: " << output_tensor_name << ".";
} }
} }
} }
......
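Per the converter hunk above, the quant_dequant_linear flow attaches each tensor's scale to the quantized op as a float attribute keyed by the op's own input/output parameter name, and the generic converter loop then feeds those scales into SetTensorDynamicRange. A hedged sketch of what such an op description could look like (op type, tensor names, and scale values are illustrative, not taken from this diff):

```cpp
#include "paddle/fluid/framework/op_desc.h"

// Hedged sketch: per-slot scale attributes that the generic OpConverter loop
// above would pick up. Names and values are illustrative only.
paddle::framework::OpDesc op_desc;
op_desc.SetType("matmul");
op_desc.SetInput("X", {"quant_x"});
op_desc.SetOutput("Out", {"quant_out"});
op_desc.SetAttr("X", 0.017f);    // scale of the tensor bound to input slot "X"
op_desc.SetAttr("Out", 0.021f);  // scale of the tensor bound to output slot "Out"
```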
...@@ -132,11 +132,10 @@ class Pool2dOpConverter : public OpConverter { ...@@ -132,11 +132,10 @@ class Pool2dOpConverter : public OpConverter {
} }
if (op_desc.HasAttr("enable_int8")) { if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000) CHECK(op_desc.HasAttr("Input_scale"));
CHECK(op_desc.HasAttr("X_scale")); float input_scale =
float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input1, input_scale); engine_->SetTensorDynamicRange(input1, input_scale);
#endif
} }
std::vector<int> real_paddings = paddings; std::vector<int> real_paddings = paddings;
......
...@@ -123,8 +123,9 @@ class Pool3dOpConverter : public OpConverter { ...@@ -123,8 +123,9 @@ class Pool3dOpConverter : public OpConverter {
nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]); nvinfer1::Dims3 nv_paddings(paddings[0], paddings[1], paddings[2]);
nvinfer1::ILayer *layer = nullptr; nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) { if (op_desc.HasAttr("enable_int8")) {
CHECK(op_desc.HasAttr("X_scale")); CHECK(op_desc.HasAttr("Input_scale"));
float input_scale = BOOST_GET_CONST(float, op_desc.GetAttr("X_scale")); float input_scale =
BOOST_GET_CONST(float, op_desc.GetAttr("Input_scale"));
engine_->SetTensorDynamicRange(input1, input_scale); engine_->SetTensorDynamicRange(input1, input_scale);
} }
......
...@@ -70,7 +70,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter { ...@@ -70,7 +70,7 @@ class PrelnEmbEltwiseLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>(); auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims(); (*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data; return temp_data;
}; };
......
...@@ -48,7 +48,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter { ...@@ -48,7 +48,7 @@ class PrelnSkipLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>(); auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims(); (*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data; return temp_data;
}; };
......
...@@ -57,8 +57,8 @@ class PReluOpConverter : public OpConverter { ...@@ -57,8 +57,8 @@ class PReluOpConverter : public OpConverter {
layer = engine_->AddDynamicPlugin(&input, input_num, plugin); layer = engine_->AddDynamicPlugin(&input, input_num, plugin);
} else { } else {
#if IS_TRT_VERSION_GE(7000) #if IS_TRT_VERSION_GE(7000)
float* alpha_weight_data = engine_->GetWeightCPUData( float* alpha_weight_data =
op_desc.Input("Alpha")[0], alpha_tensor, false); engine_->GetWeightCPUData(op_desc.Input("Alpha")[0], alpha_tensor);
TensorRTEngine::Weight alpha_weight{ TensorRTEngine::Weight alpha_weight{
nvinfer1::DataType::kFLOAT, static_cast<void*>(alpha_weight_data), nvinfer1::DataType::kFLOAT, static_cast<void*>(alpha_weight_data),
static_cast<size_t>(alpha_tensor->numel())}; static_cast<size_t>(alpha_tensor->numel())};
......
...@@ -40,7 +40,7 @@ class SkipLayerNormOpConverter : public OpConverter { ...@@ -40,7 +40,7 @@ class SkipLayerNormOpConverter : public OpConverter {
auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>(); auto* temp_tensor = temp_var->GetMutable<framework::LoDTensor>();
(*dims) = temp_tensor->dims(); (*dims) = temp_tensor->dims();
auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor, false); auto* temp_data = engine_->GetWeightCPUData(var_name, temp_tensor);
return temp_data; return temp_data;
}; };
......
...@@ -356,9 +356,7 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) { ...@@ -356,9 +356,7 @@ void TensorRTEngine::SetRuntimeBatch(size_t batch_size) {
} }
float *TensorRTEngine::GetWeightCPUData(const std::string &name, float *TensorRTEngine::GetWeightCPUData(const std::string &name,
framework::Tensor *weight_tensor, framework::Tensor *weight_tensor) {
bool enable_int8,
const std::vector<float> &scale) {
static int name_suffix_counter = 0; static int name_suffix_counter = 0;
std::string name_suffix = std::to_string(name_suffix_counter); std::string name_suffix = std::to_string(name_suffix_counter);
std::string splitter = "__"; std::string splitter = "__";
......
...@@ -389,8 +389,7 @@ class TensorRTEngine { ...@@ -389,8 +389,7 @@ class TensorRTEngine {
} }
float* GetWeightCPUData(const std::string& name, float* GetWeightCPUData(const std::string& name,
framework::Tensor* weight_tensor, bool enable_int8, framework::Tensor* weight_tensor);
const std::vector<float>& scale = {});
// A pointer to CPU memory is needed of the TRT weight. // A pointer to CPU memory is needed of the TRT weight.
// Before TRT runs, fluid loads weight into GPU storage. // Before TRT runs, fluid loads weight into GPU storage.
......
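The two engine hunks above drop the enable_int8 and scale parameters from TensorRTEngine::GetWeightCPUData: converters now always fetch the raw FP32 weights, and int8 ranges are set separately through attribute-driven SetTensorDynamicRange calls. A minimal before/after sketch of a call site (variable names are illustrative):

```cpp
// Old signature: the helper could pre-scale weights for int8.
// float* w = engine_->GetWeightCPUData(weight_name, weight_tensor,
//                                      /*enable_int8=*/true, weight_scale);

// New signature: always return the raw FP32 weight buffer.
float* w = engine_->GetWeightCPUData(weight_name, weight_tensor);
```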
type: "dequantize_linear"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "ZeroPoint"
}
outputs {
name: "Y"
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "quant_axis"
type: INT
}
}
extra {
}
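The new dequantize_linear op definition above only declares the interface (inputs X/Scale/ZeroPoint, output Y, attrs bit_length and quant_axis). As a rough illustration of the per-tensor case, here is a minimal sketch assuming Paddle's abs-max style convention, y = x * scale / (2^(bit_length-1) - 1); the real kernel and its zero-point handling may differ:

```cpp
#include <vector>

// Hedged sketch of per-tensor dequantize_linear (quant_axis = -1), assuming
// the abs-max convention y = x * scale / (2^(bit_length - 1) - 1).
std::vector<float> DequantizeLinear(const std::vector<float>& x, float scale,
                                    int bit_length) {
  const float bin_cnt = static_cast<float>((1 << (bit_length - 1)) - 1);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    y[i] = x[i] * scale / bin_cnt;
  }
  return y;
}
```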
...@@ -60,15 +60,7 @@ extra { ...@@ -60,15 +60,7 @@ extra {
type: BOOLEAN type: BOOLEAN
} }
attrs { attrs {
name: "X_scale" name: "Input_scale"
type: FLOAT
}
attrs {
name: "weight_scale"
type: FLOAT
}
attrs {
name: "out_scale"
type: FLOAT type: FLOAT
} }
attrs { attrs {
......
type: "quantize_linear"
def {
inputs {
name: "X"
}
inputs {
name: "Scale"
}
inputs {
name: "ZeroPoint"
}
outputs {
name: "Y"
}
attrs {
name: "bit_length"
type: INT
}
attrs {
name: "quant_axis"
type: INT
}
}
extra {
}
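Similarly, the quantize_linear definition above maps FP32 values onto the signed bit_length-bit grid. A minimal per-tensor sketch, again assuming the abs-max convention with rounding and clipping (the exact kernel may differ):

```cpp
#include <algorithm>
#include <cmath>
#include <vector>

// Hedged sketch of per-tensor quantize_linear (quant_axis = -1), assuming
// y = clip(round(x / scale * bin_cnt), -bin_cnt, bin_cnt) with
// bin_cnt = 2^(bit_length - 1) - 1.
std::vector<float> QuantizeLinear(const std::vector<float>& x, float scale,
                                  int bit_length) {
  const float bin_cnt = static_cast<float>((1 << (bit_length - 1)) - 1);
  std::vector<float> y(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    float q = std::round(x[i] / scale * bin_cnt);
    y[i] = std::min(std::max(q, -bin_cnt), bin_cnt);
  }
  return y;
}
```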
...@@ -491,8 +491,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): ...@@ -491,8 +491,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"x_num_col_dims": 2, "x_num_col_dims": 2,
"y_num_col_dims": 1, "y_num_col_dims": 1,
"enable_int8": True, "enable_int8": True,
"X_scale": 1.0, "Input_scale": 1.0,
"weight_scale": [1.0],
}, { }, {
"axis": 2, "axis": 2,
"out_threshold": 1.0, "out_threshold": 1.0,
...@@ -504,8 +503,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): ...@@ -504,8 +503,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"x_num_col_dims": 2, "x_num_col_dims": 2,
"y_num_col_dims": 1, "y_num_col_dims": 1,
"enable_int8": True, "enable_int8": True,
"X_scale": 1.0, "Input_scale": 1.0,
"weight_scale": [1.0],
}, { }, {
"axis": 2, "axis": 2,
"out_threshold": 1.0, "out_threshold": 1.0,
...@@ -517,8 +515,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest): ...@@ -517,8 +515,7 @@ class TrtConvertMultiHeadMatmulTestInt8(TrtConvertMultiHeadMatmulTest):
"x_num_col_dims": 2, "x_num_col_dims": 2,
"y_num_col_dims": 1, "y_num_col_dims": 1,
"enable_int8": True, "enable_int8": True,
"X_scale": 1.0, "Input_scale": 1.0,
"weight_scale": [1.0],
}, { }, {
"axis": 2, "axis": 2,
"out_threshold": 1.0, "out_threshold": 1.0,
......