未验证 提交 d7858c99 编写于 作者: W Wangzheee 提交者: GitHub

[PaddleInference] Pass: add int8 flag for op (#36042)

* add_int_pass

* add_int8_flag_pass

* add_int8_flag_pass

* fix CMakeLists.txt

* fix test_trt_fc_fuse_quant_dequant_pass.py

* fix python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_quant_dequant_pass.py

* fix test_trt_fc_fuse_quant_dequant_pass.py
上级 caa2003a
...@@ -97,6 +97,7 @@ pass_library(multihead_matmul_fuse_pass inference) ...@@ -97,6 +97,7 @@ pass_library(multihead_matmul_fuse_pass inference)
pass_library(adaptive_pool2d_convert_global_pass inference) pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference) pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference) pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(generate_pass DEPS pass_desc_proto) pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto) target_link_libraries(generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM) if(WITH_GPU OR WITH_ROCM)
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/add_support_int8_pass.h"
namespace paddle {
namespace framework {
namespace ir {
#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES \
GET_IR_NODE(prev_op); \
GET_IR_NODE(prev_out); \
GET_IR_NODE(quant_op); \
GET_IR_NODE(quant_out);
void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
const std::string pattern_name = "add_support_int8";
FusePassBase::Init(pattern_name, graph);
GraphPatternDetector gpd;
patterns::AddSupportInt8 pattern(gpd.mutable_pattern(), pattern_name);
pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
GET_NODES;
if (prev_op->Op()->HasAttr("out_threshold") &&
quant_op->Op()->HasAttr("out_threshold")) {
quant_op->Op()->SetAttr("support_int8", true);
}
found_count++;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(add_support_int8_pass, paddle::framework::ir::AddSupportInt8Pass);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
class AddSupportInt8Pass : public FusePassBase {
public:
AddSupportInt8Pass() {}
virtual ~AddSupportInt8Pass() {}
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
...@@ -2986,6 +2986,29 @@ PDNode *patterns::LayerNorm::operator()() { ...@@ -2986,6 +2986,29 @@ PDNode *patterns::LayerNorm::operator()() {
return shift_out; return shift_out;
} }
// Add support int8 flag
PDNode *patterns::AddSupportInt8::operator()() {
auto prev_op =
pattern->NewNode(prev_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return node->Op()->HasAttr("out_threshold") ? true : false;
});
auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var();
auto quant_op =
pattern->NewNode(quant_op_repr())
->assert_is_op()
->assert_more([&](Node *node) {
return node->Op()->HasAttr("out_threshold") ? true : false;
});
auto quant_out =
pattern->NewNode(quant_out_repr())->assert_is_var()->AsOutput();
prev_op->LinksTo({prev_out});
prev_out->LinksTo({quant_op});
quant_op->LinksTo({quant_out});
return quant_out;
}
} // namespace ir } // namespace ir
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -1682,6 +1682,18 @@ struct LayerNorm : public PatternBase { ...@@ -1682,6 +1682,18 @@ struct LayerNorm : public PatternBase {
PATTERN_DECL_NODE(shift_out); PATTERN_DECL_NODE(shift_out);
}; };
// Add support int8 flag
struct AddSupportInt8 : public PatternBase {
AddSupportInt8(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "Add_support_int8") {}
PDNode* operator()();
PATTERN_DECL_NODE(prev_op);
PATTERN_DECL_NODE(prev_out);
PATTERN_DECL_NODE(quant_op);
PATTERN_DECL_NODE(quant_out);
};
} // namespace patterns } // namespace patterns
// Link two ir::Nodes from each other. // Link two ir::Nodes from each other.
......
...@@ -96,6 +96,7 @@ const std::vector<std::string> kTRTSubgraphPasses({ ...@@ -96,6 +96,7 @@ const std::vector<std::string> kTRTSubgraphPasses({
"map_matmul_to_mul_pass", // "map_matmul_to_mul_pass", //
"fc_fuse_pass", // "fc_fuse_pass", //
"conv_elementwise_add_fuse_pass", // "conv_elementwise_add_fuse_pass", //
"add_support_int8_pass",
"tensorrt_subgraph_pass", // "tensorrt_subgraph_pass", //
"conv_bn_fuse_pass", // "conv_bn_fuse_pass", //
#if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be #if CUDNN_VERSION >= 7100 // To run conv_fusion, the version of cudnn must be
......
...@@ -59,6 +59,8 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -59,6 +59,8 @@ struct SimpleOpTypeSetTeller : public Teller {
#if CUDA_VERSION >= 10020 #if CUDA_VERSION >= 10020
teller_set.insert("reshape"); teller_set.insert("reshape");
teller_set.insert("reshape2"); teller_set.insert("reshape2");
int8_teller_set.insert("reshape");
int8_teller_set.insert("reshape2");
#endif #endif
} }
...@@ -91,7 +93,9 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -91,7 +93,9 @@ struct SimpleOpTypeSetTeller : public Teller {
"scale", "scale",
"elementwise_mul", "elementwise_mul",
"conv2d_transpose", "conv2d_transpose",
"hard_swish"}; "hard_swish",
"transpose",
"transpose2"};
std::unordered_set<std::string> teller_set{"mul", std::unordered_set<std::string> teller_set{"mul",
"matmul", "matmul",
"conv2d", "conv2d",
......
...@@ -86,15 +86,14 @@ class FCQuantDequantFusePassTRTDims3Cols2Test(QuantDequantTest): ...@@ -86,15 +86,14 @@ class FCQuantDequantFusePassTRTDims3Cols2Test(QuantDequantTest):
self.data = fluid.data( self.data = fluid.data(
name='data', shape=[1, 28, 28], dtype='float32') name='data', shape=[1, 28, 28], dtype='float32')
self.label = fluid.data(name='label', shape=[1, 1], dtype='int64') self.label = fluid.data(name='label', shape=[1, 1], dtype='int64')
label_shape = fluid.layers.reshape(self.label, shape=[1, 1, 1])
fc_out = fluid.layers.fc(input=self.data, fc_out = fluid.layers.fc(input=self.data,
size=28, size=28,
num_flatten_dims=2, num_flatten_dims=2,
bias_attr=False, bias_attr=False,
act=None) act=None)
c_out = fluid.layers.reshape(fc_out, shape=[1, 1, 784]) c_out = fluid.layers.reshape(fc_out, shape=[0, 784])
result = fluid.layers.relu(c_out) result = fluid.layers.relu(c_out)
loss = fluid.layers.cross_entropy(input=result, label=label_shape) loss = fluid.layers.cross_entropy(input=result, label=self.label)
avg_loss = fluid.layers.mean(loss) avg_loss = fluid.layers.mean(loss)
return avg_loss, result return avg_loss, result
...@@ -119,11 +118,11 @@ class FCQuantDequantFusePassTRTDims3Cols2Test(QuantDequantTest): ...@@ -119,11 +118,11 @@ class FCQuantDequantFusePassTRTDims3Cols2Test(QuantDequantTest):
self.dynamic_shape_params = FCQuantDequantFusePassTRTDims3Cols2Test.DynamicShapeParam( self.dynamic_shape_params = FCQuantDequantFusePassTRTDims3Cols2Test.DynamicShapeParam(
{ {
'data': [1, 28, 28], 'data': [1, 28, 28],
'reshape2_1.tmp_0': [1, 1, 784] 'reshape2_0.tmp_0': [1, 784]
}, {'data': [4, 28, 28], }, {'data': [4, 28, 28],
'reshape2_1.tmp_0': [4, 1, 784]}, 'reshape2_0.tmp_0':
{'data': [1, 28, 28], [4, 784]}, {'data': [1, 28, 28],
'reshape2_1.tmp_0': [1, 1, 784]}, False) 'reshape2_0.tmp_0': [1, 784]}, False)
self.activation_quantize_type = 'moving_average_abs_max' self.activation_quantize_type = 'moving_average_abs_max'
self.weight_quantize_type = 'channel_wise_abs_max' self.weight_quantize_type = 'channel_wise_abs_max'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册