Unverified commit 38faed7f, authored by alncat, committed by GitHub

Added support for inference using quantization aware trained dygraph (#30288) (#30402)

Parent: 5d30d072
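With the passes added in this commit, a dygraph model exported after quantization-aware training can be served by the TensorRT subgraph engine in INT8 without a separate calibration step: the fake quant/dequant ops are folded away and their scales are carried on the fused ops as out_threshold and weight_scale attributes. The following is a minimal sketch of how such a model might be loaded with the C++ inference API of this release; the model paths and header name are placeholders, and the EnableTensorRtEngine argument order (workspace, max batch, min subgraph size, precision, use_static, use_calib_mode) and the CreatePaddlePredictor call are assumptions, not taken from this commit.

#include "paddle_inference_api.h"  // header name/path may differ by install layout

int main() {
  paddle::AnalysisConfig config;
  // Placeholder paths to a model exported from quantization-aware training.
  config.SetModel("qat_model/__model__", "qat_model/__params__");
  config.EnableUseGpu(100 /* initial GPU memory pool in MB */, 0 /* device id */);
  // Run the TensorRT subgraph engine in INT8; the scales are read from the
  // program itself, so calibration mode is turned off.
  config.EnableTensorRtEngine(1 << 30 /* workspace bytes */, 1 /* max batch */,
                              3 /* min subgraph size */,
                              paddle::AnalysisConfig::Precision::kInt8,
                              false /* use_static */, false /* use_calib_mode */);
  auto predictor = paddle::CreatePaddlePredictor(config);
  // ... set input tensors on the predictor and run it.
  return 0;
}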
@@ -85,6 +85,7 @@ pass_library(runtime_context_cache_pass base)
pass_library(quant_conv2d_dequant_fuse_pass inference)
pass_library(shuffle_channel_detect_pass inference)
pass_library(delete_quant_dequant_op_pass inference)
pass_library(delete_quant_dequant_filter_op_pass inference)
pass_library(simplify_with_basic_ops_pass base)
pass_library(fc_elementwise_layernorm_fuse_pass base)
pass_library(skip_layernorm_fuse_pass base)
...
@@ -62,6 +62,14 @@ void ConvElementwiseAddFusePass::ApplyImpl(ir::Graph* graph) const {
  new_op_desc.SetOutput("Output", {output_name});
  new_op_desc.SetAttr("is_test", true);
  new_op_desc.SetAttr("use_cudnn", false);
auto* elementwise_add_op_desc = elementwise_add_op->Op();
auto out_threshold_attr =
elementwise_add_op_desc->GetNullableAttr("out_threshold");
// set the out_threshold of the elementwise add op to be the out_threshold
// of the conv2d_fusion
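  // GetNullableAttr returns a default-constructed Attribute (boost::blank)
  // when the attribute is absent, so which() != 0 means it was actually set.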
if (out_threshold_attr.which()) {
new_op_desc.SetAttr("out_threshold", out_threshold_attr);
}
  new_op_desc.Flush();
  // Create a new node for the fused op.
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/delete_quant_dequant_filter_op_pass.h"
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(quant_dequant_op_x); \
GET_IR_NODE(quant_dequant_op); \
GET_IR_NODE(quant_dequant_op_out); \
GET_IR_NODE(quant_dequant_op_outscale); \
GET_IR_NODE(any_op2);
// Delete quant_dequant_op, then quantize and dequantize weight
void DeleteQuantDequantFilterOpPass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "delete_quantdequant_filter_op_pattern";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
// Create pattern
patterns::DeleteQuantDequantFilterOpPattern pattern(gpd.mutable_pattern(),
pattern_name);
pattern();
auto* scope = param_scope();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
std::unordered_set<const Node*> nodes2rm = {};
int bit_length =
BOOST_GET_CONST(int, quant_dequant_op->Op()->GetAttr("bit_length"));
int range = ((1 << (bit_length - 1)) - 1);
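// e.g. bit_length == 8 gives range == 127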
std::vector<float> weight_scale;
std::string quant_dequant_op_out_name = quant_dequant_op_out->Var()->Name();
auto* any_op2_desc = any_op2->Op();
auto var_map = any_op2_desc->Inputs();
std::string arg_name = "";
for (auto& name_m : var_map) {
if (std::find(name_m.second.begin(), name_m.second.end(),
quant_dequant_op_out_name) != name_m.second.end()) {
arg_name = name_m.first;
break;
}
}
PADDLE_ENFORCE_GT(arg_name.size(), 0, platform::errors::InvalidArgument(
"can not find the input %s.",
quant_dequant_op_out_name));
any_op2_desc->SetAttr("enable_int8", true);
any_op2_desc->SetAttr("bit_length", bit_length);
// modify the any_op2's inputs
any_op2_desc->Flush();
auto dequant_type = quant_dequant_op->Op()->Type();
auto quantized_op_type = any_op2_desc->Type();
// Get weight scale
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
auto scales_name = quant_dequant_op->Op()->Output("OutScale");
PADDLE_ENFORCE_EQ(scales_name.size(), 1,
platform::errors::InvalidArgument(
"Scales size in channel-wise quant dequantize op "
"should be 1, got %d.",
scales_name.size()));
const LoDTensor& channel_scale_tensor =
scope->GetVar(scales_name[0])->Get<LoDTensor>();
PADDLE_ENFORCE(
paddle::platform::is_cpu_place(channel_scale_tensor.place()),
platform::errors::InvalidArgument(
"Channel scale tensor's place should be CPU."));
const float* channel_scale_data = channel_scale_tensor.data<float>();
for (int i = 0; i < channel_scale_tensor.numel(); i++) {
weight_scale.push_back(range / channel_scale_data[i]);
}
} else {
auto scale_name = quant_dequant_op_outscale->Name();
const LoDTensor& scale_tensor =
scope->GetVar(scale_name)->Get<LoDTensor>();
const float* scale_data = scale_tensor.data<float>();
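// (range * range) / scale / range reduces to range / scale, the same form as
// the per-channel branch above.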
weight_scale.push_back((range * range) / scale_data[0] / range);
}
nodes2rm.insert(quant_dequant_op_outscale);
// perform quantize dequantize operations
auto* weight_tensor =
scope->GetVar(quant_dequant_op_x->Name())->GetMutable<LoDTensor>();
auto w_dims = weight_tensor->dims();
float* quantized_weight_data =
weight_tensor->mutable_data<float>(platform::CPUPlace());
// If quantized op is fc, weight scale size = 1;
// If quantized op is conv2d, weight scale size = weight dims[0]
// If quantized op is conv2d_transpose, weight scale size = weight dims[1]
if (dequant_type == "fake_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), 1,
platform::errors::InvalidArgument(
"%s op weight dequantized by [fake_quantize_dequantize_max_abs] "
"requires weight scale size = 1, but got %d.",
quantized_op_type, weight_scale.size()));
PADDLE_ENFORCE_NE(weight_scale[0], 0,
platform::errors::InvalidArgument(
"Weight scale should be nonzero, but get zero"));
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
quantized_weight_data[j] = quantized_weight_data[j] * weight_scale[0];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[0];
}
} else if (quantized_op_type == "mul" || quantized_op_type == "matmul" ||
quantized_op_type == "fc") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"mul op weight dequantized by "
"[fake_channel_wise_quantize_dequantize_abs_max] requires "
"weight scale "
"size = 2nd dim of mul's weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
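// mul/fc weights are laid out as [in_dim, out_dim]; channel-wise scales run
// along dim 1, so the scale index for element j is j % w_dims[1].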
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j % w_dims[1]], 0,
platform::errors::InvalidArgument(
"fc op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d" ||
quantized_op_type == "depthwise_conv2d") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[0]),
platform::errors::InvalidArgument(
"conv2d op requires weight scale size = channel size of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[0]), weight_scale.size()));
int inner_size = w_dims[1] * w_dims[2] * w_dims[3];
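// conv2d filters are laid out as [out_c, in_c, kh, kw], so j / inner_size is
// the output-channel index of element j.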
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(
weight_scale[j / inner_size], 0,
platform::errors::InvalidArgument(
"conv2d op weight scale should be nonzero, but get zero"));
quantized_weight_data[j] =
quantized_weight_data[j] * weight_scale[j / inner_size];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /= weight_scale[j / inner_size];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else if (quantized_op_type == "conv2d_transpose") {
if (dequant_type == "fake_channel_wise_quantize_dequantize_abs_max") {
PADDLE_ENFORCE_EQ(
weight_scale.size(), static_cast<size_t>(w_dims[1]),
platform::errors::InvalidArgument(
"conv2d_transpose op requires weight scale size = channel size "
"of the "
"weight, which is %zu, but got %zu.",
static_cast<size_t>(w_dims[1]), weight_scale.size()));
int inner_size = w_dims[2] * w_dims[3];
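// conv2d_transpose filters are [in_c, out_c / groups, kh, kw];
// (j / inner_size) % w_dims[1] picks the scale for element j's output channel.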
for (int j = 0; j < weight_tensor->numel(); j++) {
// quantized
PADDLE_ENFORCE_NE(weight_scale[(j / inner_size) % w_dims[1]], 0,
platform::errors::InvalidArgument(
"conv2d_transpose op weight scale should be "
"nonzero, but get zero"));
quantized_weight_data[j] = quantized_weight_data[j] *
weight_scale[(j / inner_size) % w_dims[1]];
quantized_weight_data[j] = std::round(quantized_weight_data[j]);
// dequantized
quantized_weight_data[j] /=
weight_scale[(j / inner_size) % w_dims[1]];
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported quantized op type: %s", quantized_op_type));
}
nodes2rm.insert(quant_dequant_op_out);
// link weight in quant_dequant_op_x to any_op2
any_op2_desc->RenameInput(quant_dequant_op_out->Var()->Name(),
quant_dequant_op_x->Var()->Name());
any_op2_desc->SetAttr("weight_scale", weight_scale);
any_op2_desc->Flush();
IR_NODE_LINK_TO(quant_dequant_op_x, any_op2);
nodes2rm.insert(quant_dequant_op);
GraphSafeRemoveNodes(graph, nodes2rm);
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(delete_quant_dequant_filter_op_pass,
paddle::framework::ir::DeleteQuantDequantFilterOpPass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class DeleteQuantDequantFilterOpPass : public FusePassBase {
public:
virtual ~DeleteQuantDequantFilterOpPass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -49,10 +49,10 @@ void DeleteQuantDequantOpPass::ApplyImpl(ir::Graph* graph) const {
    std::string input_scale_var_name =
        quant_dequant_op->Op()->Input("InScale").front();
    const LoDTensor& input_scale_tensor =
-       scope->FindVar(input_scale_var_name)->Get<LoDTensor>();
+       scope->GetVar(input_scale_var_name)->Get<LoDTensor>();
    const float* input_scale_data = input_scale_tensor.data<float>();
-   float input_scale = input_scale_data[0];
+   float input_scale = input_scale_data[0] / 127.;
    auto* any_op2_desc = any_op2->Op();
    // auto input_args_names = any_op2_desc->InputArgumentNames();
    auto var_map = any_op2_desc->Inputs();
...
@@ -149,6 +149,18 @@ int FCFusePass::ApplyFCPattern(Graph* graph, bool with_relu) const {
    desc.SetAttr("out_scale", elementwise_desc->GetAttr("out_scale"));
  }
auto* elementwise_add_op_desc = elementwise_add->Op();
// if we can find out_threshold in elementwise_add, then set it as the
// out_threshold of fc
auto out_threshold_attr =
elementwise_add_op_desc->GetNullableAttr("out_threshold");
if (out_threshold_attr.which()) {
VLOG(4) << "setting out_threshold: "
<< BOOST_GET_CONST(float, out_threshold_attr);
desc.SetAttr("out_threshold", out_threshold_attr);
}
desc.Flush();
  auto fc_node = g->CreateOpNode(&desc);  // OpDesc will be copied.
  if (with_relu) {
    GraphSafeRemoveNodes(
...
@@ -1634,6 +1634,27 @@ PDNode *patterns::MatmulWithInputOps::operator()() {
  return matmul_out;
}
PDNode *patterns::Flatten2Matmul::operator()() {
auto flatten2_in_x = pattern->NewNode(flatten2_in_x_repr())
->assert_is_op_input("flatten2", "X")
->AsInput();
auto flatten2_op =
pattern->NewNode(flatten2_op_repr())->assert_is_op("flatten2");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->assert_is_op_output("flatten2", "Out")
->assert_is_op_input("matmul", "X");
auto matmul_in_y =
pattern->NewNode(matmul_in_y_repr())->assert_is_op_input("matmul", "Y");
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_out = pattern->NewNode(matmul_out_repr())
->AsOutput()
->assert_is_op_output("matmul", "Out");
flatten2_op->LinksFrom({flatten2_in_x}).LinksTo({matmul_in_x});
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({matmul_out});
return matmul_out;
}
PDNode *patterns::ConvResidual::operator()(bool with_residual_data) {
  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
@@ -2495,6 +2516,43 @@ void patterns::DeleteQuantDequantOpPattern::operator()() {
  any_op2->LinksFrom({quant_dequant_out});
}
void patterns::DeleteQuantDequantFilterOpPattern::operator()() {
auto quant_dequant_op_x =
pattern->NewNode(quant_dequant_op_x_repr())
->assert_is_ops_input(
{"fake_channel_wise_quantize_dequantize_abs_max",
"fake_quantize_dequantize_abs_max"},
"X")
->AsInput();
auto quant_dequant_op =
pattern->NewNode(quant_dequant_op_repr())
->assert_is_ops({"fake_channel_wise_quantize_dequantize_abs_max",
"fake_quantize_dequantize_abs_max"});
auto quant_dequant_out =
pattern->NewNode(quant_dequant_op_out_repr())
->assert_is_ops_output(
{"fake_channel_wise_quantize_dequantize_abs_max",
"fake_quantize_dequantize_abs_max"},
"Out")
->AsIntermediate();
auto quant_dequant_op_outscale =
pattern->NewNode(quant_dequant_op_outscale_repr())
->assert_is_ops_output(
{"fake_channel_wise_quantize_dequantize_abs_max",
"fake_quantize_dequantize_abs_max"},
"OutScale")
->AsOutput();
auto any_op2 = pattern->NewNode(any_op2_repr())->assert_is_op()->AsOutput();
quant_dequant_op->LinksFrom({quant_dequant_op_x});
quant_dequant_op_outscale->LinksFrom({quant_dequant_op});
quant_dequant_out->LinksFrom({quant_dequant_op});
any_op2->LinksFrom({quant_dequant_out});
}
PDNode *patterns::ReshapeTransposeMatmulPattern::operator()(
    bool with_reshape_xshape, bool with_transpose_xshape) {
  auto reshape_op =
...
@@ -996,6 +996,21 @@ struct MatmulWithInputOps : public PatternBase {
  PATTERN_DECL_NODE(matmul_out);
};
// Flatten2 + Matmul
// Forward pass.
struct Flatten2Matmul : public PatternBase {
Flatten2Matmul(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "flatten2_matmul") {}
PDNode* operator()();
PATTERN_DECL_NODE(flatten2_in_x);
PATTERN_DECL_NODE(flatten2_op);
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(matmul_out);
};
// Concat op
// Forward pass for concat.
// concat_out is a result of the operator.
@@ -1426,6 +1441,21 @@ struct DeleteQuantDequantOpPattern : public PatternBase {
  PATTERN_DECL_NODE(any_op2);
};
struct DeleteQuantDequantFilterOpPattern : public PatternBase {
DeleteQuantDequantFilterOpPattern(PDPattern* pattern,
const std::string& name_scope)
: PatternBase(pattern, name_scope,
"delete_quantdequant_filter_op_pattern") {}
void operator()();
PATTERN_DECL_NODE(quant_dequant_op_x);
PATTERN_DECL_NODE(quant_dequant_op);
PATTERN_DECL_NODE(quant_dequant_op_outscale);
PATTERN_DECL_NODE(quant_dequant_op_out);
PATTERN_DECL_NODE(any_op2);
};
// Reshape + Transpose + Matmul
// named nodes:
// reshape_op, reshape_out, reshape_xshape,
...
@@ -71,7 +71,11 @@ void MapMatmul2MulPass::ApplyImpl(ir::Graph* graph) const {
    desc.SetOutput("Out", {matmul_out->Name()});
    desc.SetAttr("x_num_col_dims", 1);
    desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
    auto mul_node = g->CreateOpNode(&desc);
    IR_NODE_LINK_TO(matmul_in_x, mul_node);
    IR_NODE_LINK_TO(matmul_in_y, mul_node);
@@ -137,7 +141,11 @@ void Squeeze2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
    desc.SetOutput("Out", {matmul_out->Name()});
    desc.SetAttr("x_num_col_dims", 1);
    desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
    auto mul_node = g->CreateOpNode(&desc);
    IR_NODE_LINK_TO(squeeze2_in_x, mul_node);
    IR_NODE_LINK_TO(matmul_in_y, mul_node);
@@ -205,7 +213,11 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
    desc.SetOutput("Out", {matmul_out->Name()});
    desc.SetAttr("x_num_col_dims", 1);
    desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
    auto mul_node = g->CreateOpNode(&desc);
    IR_NODE_LINK_TO(reshape2_in_x, mul_node);
    IR_NODE_LINK_TO(matmul_in_y, mul_node);
@@ -219,6 +231,83 @@ void Reshape2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
  AddStatis(found_count);
}
void Flatten2MatmulFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "flatten2_matmul_fuse_pass";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::Flatten2Matmul fuse_pattern(gpd.mutable_pattern(), name_scope);
fuse_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "fuse flatten2+matmul to mul";
GET_IR_NODE_FROM_SUBGRAPH(flatten2_in_x, flatten2_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(flatten2_op, flatten2_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, fuse_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_out, matmul_out, fuse_pattern);
bool pattern_found = true;
size_t flatten2_in_nums = flatten2_op->inputs.size();
auto flatten2_in_x_shape = flatten2_in_x->Var()->GetShape();
size_t flatten2_in_x_rank = flatten2_in_x_shape.size();
int flatten2_axis =
BOOST_GET_CONST(int, flatten2_op->Op()->GetAttr("axis"));
// only convert matmul to mul when flatten2 has a single input, the rank of
// that input is 4, and the flatten2 output feeds only this matmul.
pattern_found = pattern_found && flatten2_in_nums == 1 &&
flatten2_in_x_rank == 4 &&
(matmul_in_x->outputs).size() == 1;
bool transpose_X =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_X"));
bool transpose_Y =
BOOST_GET_CONST(bool, matmul_op->Op()->GetAttr("transpose_Y"));
float alpha = BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
size_t matmul_in_x_rank = (matmul_in_x->Var()->GetShape()).size();
size_t matmul_in_y_rank = (matmul_in_y->Var()->GetShape()).size();
pattern_found = pattern_found && !transpose_X && !transpose_Y &&
std::abs(alpha - 1.0) < 1e-5 && matmul_in_x_rank == 2 &&
matmul_in_y_rank == 2;
std::vector<Node*>& next_ops = matmul_out->outputs;
// we further require that the matmul op is followed by exactly one
// elementwise_add op.
pattern_found = pattern_found && next_ops.size() == 1 &&
next_ops[0]->Name() == "elementwise_add";
if (pattern_found) {
OpDesc desc;
desc.SetType("mul");
desc.SetInput("X", {flatten2_in_x->Name()});
desc.SetInput("Y", {matmul_in_y->Name()});
desc.SetOutput("Out", {matmul_out->Name()});
desc.SetAttr("x_num_col_dims", flatten2_axis);
desc.SetAttr("y_num_col_dims", 1);
if (matmul_op->Op()->HasAttr("enable_int8")) {
desc.SetAttr("enable_int8", matmul_op->Op()->GetAttr("enable_int8"));
desc.SetAttr("X_scale", matmul_op->Op()->GetAttr("X_scale"));
desc.SetAttr("weight_scale", matmul_op->Op()->GetAttr("weight_scale"));
}
auto mul_node = g->CreateOpNode(&desc);
IR_NODE_LINK_TO(flatten2_in_x, mul_node);
IR_NODE_LINK_TO(matmul_in_y, mul_node);
IR_NODE_LINK_TO(mul_node, matmul_out);
GraphSafeRemoveNodes(graph, {flatten2_op, matmul_in_x, matmul_op});
++found_count;
}
};
gpd(graph, handler);
AddStatis(found_count);
}
}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -247,3 +336,12 @@ REGISTER_PASS_CAPABILITY(reshape2_matmul_fuse_pass)
        .LE("matmul", 1)
        .EQ("reshape2", 0)
        .EQ("mul", 0));
REGISTER_PASS(flatten2_matmul_fuse_pass,
paddle::framework::ir::Flatten2MatmulFusePass);
REGISTER_PASS_CAPABILITY(flatten2_matmul_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("flatten2", 0)
.EQ("mul", 0));
@@ -101,6 +101,14 @@ class Reshape2MatmulFusePass : public FusePassBase {
  void ApplyImpl(Graph* graph) const override;
};
class Flatten2MatmulFusePass : public FusePassBase {
public:
virtual ~Flatten2MatmulFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
}  // namespace ir
}  // namespace framework
}  // namespace paddle
@@ -83,6 +83,13 @@ Variable* Scope::FindVar(const std::string& name) const {
  return FindVarInternal(name);
}
Variable* Scope::GetVar(const std::string& name) const {
auto* var = FindVar(name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("Cannot find %s in scope.", name));
return var;
}
Variable* Scope::FindLocalVar(const std::string& name) const {
  SCOPE_VARS_READER_LOCK
  return FindVarLocally(name);
...
@@ -81,6 +81,10 @@ class Scope {
  /// Caller doesn't own the returned Variable.
  Variable* FindVar(const std::string& name) const;
  /// Get a variable in the scope or any of its ancestors. Enforce
  /// the returned Variable is not nullptr.
Variable* GetVar(const std::string& name) const;
  /// Find a variable in the current scope.
  /// Return nullptr if cannot find.
  /// Caller doesn't own the returned Variable.
...
@@ -345,7 +345,7 @@ void AnalysisConfig::Update() {
      pass_builder()->ClearPasses();
      for (const auto &pass : kTRTSubgraphPasses) {
        if (tensorrt_precision_mode_ == AnalysisConfig::Precision::kInt8 &&
-           (pass == "conv_bn_fuse_pass" || pass == "fc_fuse_pass")) {
+           (pass == "conv_bn_fuse_pass")) {
          continue;
        }
        pass_builder()->AppendPass(pass);
...
@@ -77,6 +77,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
      "shuffle_channel_detect_pass",            //
      "quant_conv2d_dequant_fuse_pass",         //
      "delete_quant_dequant_op_pass",           //
      "delete_quant_dequant_filter_op_pass",    //
      // "fc_fuse_pass",                        //
      "simplify_with_basic_ops_pass",           //
      "embedding_eltwise_layernorm_fuse_pass",  //
@@ -86,15 +87,16 @@ const std::vector<std::string> kTRTSubgraphPasses({
      "conv_bn_fuse_pass",                      //
      "squeeze2_matmul_fuse_pass",              //
      "reshape2_matmul_fuse_pass",              //
+     "flatten2_matmul_fuse_pass",              //
      "map_matmul_to_mul_pass",                 //
      "fc_fuse_pass",                           //
+     "conv_elementwise_add_fuse_pass",         //
      "tensorrt_subgraph_pass",                 //
      "conv_bn_fuse_pass",                      //
#if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                           // guaranteed at least v7
      "conv_elementwise_add_act_fuse_pass",     //
      "conv_elementwise_add2_act_fuse_pass",    //
-     "conv_elementwise_add_fuse_pass",         //
#endif                                         //
      "transpose_flatten_concat_fuse_pass",
});
@@ -118,6 +120,7 @@ GpuPassStrategy::GpuPassStrategy() : PassStrategy({}) {
      "multihead_matmul_fuse_pass_v2",       //
      "squeeze2_matmul_fuse_pass",           //
      "reshape2_matmul_fuse_pass",           //
      "flatten2_matmul_fuse_pass",           //
      "map_matmul_to_mul_pass",              //
      "fc_fuse_pass",                        //
      "fc_elementwise_layernorm_fuse_pass",  //
@@ -172,6 +175,7 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
      "seq_concat_fc_fuse_pass",     //
      "squeeze2_matmul_fuse_pass",   //
      "reshape2_matmul_fuse_pass",   //
      "flatten2_matmul_fuse_pass",   //
      "map_matmul_to_mul_pass",      //
      "fc_fuse_pass",                //
      "repeated_fc_relu_fuse_pass",  //
...
@@ -105,8 +105,18 @@ void ConvertConv2d(TensorRTEngine* engine, const framework::proto::OpDesc& op,
  TensorRTEngine::Weight weight{nvinfer1::DataType::kFLOAT,
                                static_cast<void*>(weight_data),
                                static_cast<size_t>(Y_t->numel())};
float* bias_data = nullptr;
size_t bias_size = 0;
if (op_desc.Type() == "conv2d_fusion") {
auto* bias_tensor = scope.GetVar(op_desc.Input("Bias").front());
auto* bias_tensor_data = bias_tensor->GetMutable<framework::LoDTensor>();
bias_data = engine->GetWeightCPUData(op_desc.Input("Bias").front(),
bias_tensor_data, false);
bias_size = static_cast<size_t>(bias_tensor_data->numel());
}
- TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT, nullptr, 0};
+ TensorRTEngine::Weight bias{nvinfer1::DataType::kFLOAT,
+                             static_cast<void*>(bias_data), bias_size};
  auto* layer = fadd_layer(const_cast<nvinfer1::ITensor*>(X), n_output, n_input,
                           nv_ksize, weight, bias);
  PADDLE_ENFORCE_NOT_NULL(layer,
@@ -184,4 +194,5 @@ class Deconv2dOpConverter : public OpConverter {
}  // namespace paddle
REGISTER_TRT_OP_CONVERTER(conv2d, Conv2dOpConverter);
REGISTER_TRT_OP_CONVERTER(conv2d_fusion, Conv2dOpConverter);
REGISTER_TRT_OP_CONVERTER(conv2d_transpose, Deconv2dOpConverter);
@@ -67,10 +67,11 @@ class FcOpConverter : public OpConverter {
  // assigned from CPU memory, which can't be avoided.
  float* weight_data = nullptr;
  bool enable_int8 = op_desc.HasAttr("enable_int8");
float in_scale = 0.;
  if (enable_int8) {
#if IS_TRT_VERSION_GE(5000)
    CHECK(op_desc.HasAttr(i_name + "_scale"));
-   float in_scale =
+   in_scale =
        BOOST_GET_CONST(float, op_desc.GetAttr(i_name + "_scale")) * 127;
    auto weight_scale =
        BOOST_GET_CONST(std::vector<float>, op_desc.GetAttr("weight_scale"));
@@ -131,7 +132,7 @@ class FcOpConverter : public OpConverter {
  float* bias_data = nullptr;
  int bias_num = 0;
  if (with_bias) {
-   auto* b_v = scope.FindVar(op_desc.Input("Bias").front());
+   auto* b_v = scope.GetVar(op_desc.Input("Bias").front());
    auto* b_t = b_v->GetMutable<framework::LoDTensor>();
    bias_data =
        engine_->GetWeightCPUData(op_desc.Input("Bias").front(), b_t, false);
@@ -183,6 +184,9 @@ class FcOpConverter : public OpConverter {
    auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
    reshape_layer->setReshapeDimensions(reshape_dim);
    reshape_itensor = reshape_layer->getOutput(0);
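    // The Shuffle (reshape) layer produces a new tensor; in INT8 mode it needs
    // an explicit dynamic range, taken here from the op's input scale.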
if (enable_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
  } else {
    PADDLE_ENFORCE_NE(input_dims, 1,
                      platform::errors::InvalidArgument(
@@ -200,6 +204,9 @@ class FcOpConverter : public OpConverter {
    auto* reshape_layer = TRT_ENGINE_ADD_LAYER(engine_, Shuffle, *X);
    reshape_layer->setReshapeDimensions(reshape_dim);
    reshape_itensor = reshape_layer->getOutput(0);
if (enable_int8) {
engine_->SetTensorDynamicRange(reshape_itensor, in_scale);
}
  }
  regist_fc(reshape_itensor, n_output, weight, bias);
}
...
@@ -58,6 +58,7 @@ struct SimpleOpTypeSetTeller : public Teller {
  // use this set for no calib int8.
  std::unordered_set<std::string> int8_teller_set{"mul",
                                                  "conv2d",
"conv2d_fusion",
"pool2d", "pool2d",
"relu", "relu",
"depthwise_conv2d", "depthwise_conv2d",
...@@ -76,6 +77,7 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -76,6 +77,7 @@ struct SimpleOpTypeSetTeller : public Teller {
"mul", "mul",
"matmul", "matmul",
"conv2d", "conv2d",
"conv2d_fusion",
"pool2d", "pool2d",
"relu", "relu",
"softmax", "softmax",
......