diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt
index 6f5f27400752dd9edf679a1ae249e77ed9fbbe89..a2e9fc3a3d9ac53b1cb2f3fc105dfd0c0e00b860 100644
--- a/paddle/fluid/framework/ir/CMakeLists.txt
+++ b/paddle/fluid/framework/ir/CMakeLists.txt
@@ -97,6 +97,7 @@ pass_library(multihead_matmul_fuse_pass inference)
 pass_library(adaptive_pool2d_convert_global_pass inference)
 pass_library(unsqueeze2_eltwise_fuse_pass inference)
 pass_library(layer_norm_fuse_pass inference)
+pass_library(add_support_int8_pass inference)
 pass_library(generate_pass DEPS pass_desc_proto)
 target_link_libraries(generate_pass pass_desc_proto)
 if(WITH_GPU OR WITH_ROCM)
diff --git a/paddle/fluid/framework/ir/add_support_int8_pass.cc b/paddle/fluid/framework/ir/add_support_int8_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d157d2e934acea6657e5b5b6ce6a9bd53aedc1f4
--- /dev/null
+++ b/paddle/fluid/framework/ir/add_support_int8_pass.cc
@@ -0,0 +1,54 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/add_support_int8_pass.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
+#define GET_NODES        \
+  GET_IR_NODE(prev_op);  \
+  GET_IR_NODE(prev_out); \
+  GET_IR_NODE(quant_op); \
+  GET_IR_NODE(quant_out);
+
+void AddSupportInt8Pass::ApplyImpl(ir::Graph* graph) const {
+  const std::string pattern_name = "add_support_int8";
+  FusePassBase::Init(pattern_name, graph);
+
+  GraphPatternDetector gpd;
+
+  patterns::AddSupportInt8 pattern(gpd.mutable_pattern(), pattern_name);
+  pattern();
+  int found_count = 0;
+  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
+                     Graph* g) {
+    GET_NODES;
+    if (prev_op->Op()->HasAttr("out_threshold") &&
+        quant_op->Op()->HasAttr("out_threshold")) {
+      quant_op->Op()->SetAttr("support_int8", true);
+    }
+    found_count++;
+  };
+  gpd(graph, handler);
+  AddStatis(found_count);
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(add_support_int8_pass, paddle::framework::ir::AddSupportInt8Pass);
diff --git a/paddle/fluid/framework/ir/add_support_int8_pass.h b/paddle/fluid/framework/ir/add_support_int8_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..372250d60169d3c0fa109424871c8b5b1d532940
--- /dev/null
+++ b/paddle/fluid/framework/ir/add_support_int8_pass.h
@@ -0,0 +1,36 @@
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class Graph;
+
+class AddSupportInt8Pass : public FusePassBase {
+ public:
+  AddSupportInt8Pass() {}
+  virtual ~AddSupportInt8Pass() {}
+
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.cc b/paddle/fluid/framework/ir/graph_pattern_detector.cc
index 449849762cb10190f5eedffdc2206e8e2e933999..695da372d18f3e2bc68643cd5519db2809de6bc9 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.cc
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.cc
@@ -2986,6 +2986,29 @@ PDNode *patterns::LayerNorm::operator()() {
   return shift_out;
 }
 
+// Add support int8 flag
+PDNode *patterns::AddSupportInt8::operator()() {
+  auto prev_op =
+      pattern->NewNode(prev_op_repr())
+          ->assert_is_op()
+          ->assert_more([&](Node *node) {
+            return node->Op()->HasAttr("out_threshold") ? true : false;
+          });
+  auto prev_out = pattern->NewNode(prev_out_repr())->assert_is_var();
+  auto quant_op =
+      pattern->NewNode(quant_op_repr())
+          ->assert_is_op()
+          ->assert_more([&](Node *node) {
+            return node->Op()->HasAttr("out_threshold") ? true : false;
+          });
+  auto quant_out =
+      pattern->NewNode(quant_out_repr())->assert_is_var()->AsOutput();
+  prev_op->LinksTo({prev_out});
+  prev_out->LinksTo({quant_op});
+  quant_op->LinksTo({quant_out});
+  return quant_out;
+}
+
 }  // namespace ir
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/ir/graph_pattern_detector.h b/paddle/fluid/framework/ir/graph_pattern_detector.h
index 40c3e4f59bf262ea260a3e9a784d9bc73696ed80..4afb7dfd4991b0ef1439594f4ceb18d5b5cef19b 100644
--- a/paddle/fluid/framework/ir/graph_pattern_detector.h
+++ b/paddle/fluid/framework/ir/graph_pattern_detector.h
@@ -1682,6 +1682,18 @@ struct LayerNorm : public PatternBase {
   PATTERN_DECL_NODE(shift_out);
 };
 
+// Add support int8 flag
+struct AddSupportInt8 : public PatternBase {
+  AddSupportInt8(PDPattern* pattern, const std::string& name_scope)
+      : PatternBase(pattern, name_scope, "Add_support_int8") {}
+
+  PDNode* operator()();
+  PATTERN_DECL_NODE(prev_op);
+  PATTERN_DECL_NODE(prev_out);
+  PATTERN_DECL_NODE(quant_op);
+  PATTERN_DECL_NODE(quant_out);
+};
+
 }  // namespace patterns
 
 // Link two ir::Nodes from each other.
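(Aside, not part of the patch: the pattern above matches any op -> var -> op chain in which both ops carry the "out_threshold" attribute left behind by the quantization passes, and the new pass then tags the downstream op with "support_int8". A minimal gtest-style sketch of how the pass could be exercised follows; the op types, variable names, and attribute values are illustrative assumptions, not taken from this diff.)

// Hypothetical unit-test sketch for add_support_int8_pass; op types and
// attribute values below are assumptions chosen for illustration.
#include <gtest/gtest.h>

#include <memory>

#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {
namespace ir {

TEST(AddSupportInt8Pass, TagsOpAfterQuantizedProducer) {
  // Build x -> sigmoid -> mid -> relu -> out, with "out_threshold" on
  // both ops, as the quantization passes would leave it.
  ProgramDesc prog;
  auto* block = prog.MutableBlock(0);
  block->Var("x");
  block->Var("mid");
  block->Var("out");

  auto* prev = block->AppendOp();
  prev->SetType("sigmoid");
  prev->SetInput("X", {"x"});
  prev->SetOutput("Out", {"mid"});
  prev->SetAttr("out_threshold", 1.0f);

  auto* next = block->AppendOp();
  next->SetType("relu");
  next->SetInput("X", {"mid"});
  next->SetOutput("Out", {"out"});
  next->SetAttr("out_threshold", 1.0f);

  auto graph = std::make_unique<Graph>(prog);
  auto pass = PassRegistry::Instance().Get("add_support_int8_pass");
  graph.reset(pass->Apply(graph.release()));

  // The downstream op of the matched pair should now carry support_int8.
  for (auto* node : graph->Nodes()) {
    if (node->IsOp() && node->Op()->Type() == "relu") {
      EXPECT_TRUE(node->Op()->HasAttr("support_int8"));
    }
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

USE_PASS(add_support_int8_pass);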
diff --git a/paddle/fluid/inference/api/paddle_pass_builder.cc b/paddle/fluid/inference/api/paddle_pass_builder.cc
index 704fbb2b95c8929fdb8c76072c804340b3c0fe08..47e9c1fd202a05cdb5f93108358c25b270b71c73 100644
--- a/paddle/fluid/inference/api/paddle_pass_builder.cc
+++ b/paddle/fluid/inference/api/paddle_pass_builder.cc
@@ -96,8 +96,9 @@ const std::vector<std::string> kTRTSubgraphPasses({
       "map_matmul_to_mul_pass",          //
       "fc_fuse_pass",                    //
       "conv_elementwise_add_fuse_pass",  //
-      "tensorrt_subgraph_pass",          //
-      "conv_bn_fuse_pass",               //
+      "add_support_int8_pass",
+      "tensorrt_subgraph_pass",  //
+      "conv_bn_fuse_pass",       //
 #if CUDNN_VERSION >= 7100  // To run conv_fusion, the version of cudnn must be
                            // guaranteed at least v7
 // cudnn8.0 has memory leak problem in conv + eltwise + act, so we
diff --git a/paddle/fluid/inference/tensorrt/op_teller.cc b/paddle/fluid/inference/tensorrt/op_teller.cc
index ef50aee48e2eb84ed7950793c3d233250cf07ada..59368a299c59e24afe4f563cd45f11e86f7dbe5d 100644
--- a/paddle/fluid/inference/tensorrt/op_teller.cc
+++ b/paddle/fluid/inference/tensorrt/op_teller.cc
@@ -59,6 +59,8 @@ struct SimpleOpTypeSetTeller : public Teller {
 #if CUDA_VERSION >= 10020
     teller_set.insert("reshape");
     teller_set.insert("reshape2");
+    int8_teller_set.insert("reshape");
+    int8_teller_set.insert("reshape2");
 #endif
   }
 
@@ -91,7 +93,9 @@ struct SimpleOpTypeSetTeller : public Teller {
       "scale",
       "elementwise_mul",
       "conv2d_transpose",
-      "hard_swish"};
+      "hard_swish",
+      "transpose",
+      "transpose2"};
   std::unordered_set<std::string> teller_set{"mul",
                                              "matmul",
                                              "conv2d",
diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_quant_dequant_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_quant_dequant_pass.py
index 114fa6478f8a6f2985be92dd13a3c4731bb207c3..9e1991ae1ae305afb3c473ccbaa469ef1130e189 100644
--- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_quant_dequant_pass.py
+++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_quant_dequant_pass.py
@@ -86,15 +86,14 @@ class FCQuantDequantFusePassTRTDims3Cols2Test(QuantDequantTest):
             self.data = fluid.data(
                 name='data', shape=[1, 28, 28], dtype='float32')
             self.label = fluid.data(name='label', shape=[1, 1], dtype='int64')
-            label_shape = fluid.layers.reshape(self.label, shape=[1, 1, 1])
             fc_out = fluid.layers.fc(input=self.data,
                                      size=28,
                                      num_flatten_dims=2,
                                      bias_attr=False,
                                      act=None)
-            c_out = fluid.layers.reshape(fc_out, shape=[1, 1, 784])
+            c_out = fluid.layers.reshape(fc_out, shape=[0, 784])
             result = fluid.layers.relu(c_out)
-            loss = fluid.layers.cross_entropy(input=result, label=label_shape)
+            loss = fluid.layers.cross_entropy(input=result, label=self.label)
             avg_loss = fluid.layers.mean(loss)
             return avg_loss, result
 
@@ -119,11 +118,11 @@ class FCQuantDequantFusePassTRTDims3Cols2Test(QuantDequantTest):
         self.dynamic_shape_params = FCQuantDequantFusePassTRTDims3Cols2Test.DynamicShapeParam(
             {
                 'data': [1, 28, 28],
-                'reshape2_1.tmp_0': [1, 1, 784]
+                'reshape2_0.tmp_0': [1, 784]
             }, {'data': [4, 28, 28],
-                'reshape2_1.tmp_0': [4, 1, 784]},
-            {'data': [1, 28, 28],
-             'reshape2_1.tmp_0': [1, 1, 784]}, False)
+                'reshape2_0.tmp_0':
+                [4, 784]}, {'data': [1, 28, 28],
+                            'reshape2_0.tmp_0': [1, 784]}, False)
         self.activation_quantize_type = 'moving_average_abs_max'
         self.weight_quantize_type = 'channel_wise_abs_max'
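(Aside, not part of the patch: with add_support_int8_pass placed just before tensorrt_subgraph_pass in kTRTSubgraphPasses, the support_int8 flag is attached before the TensorRT subgraph is carved out. A rough sketch of the C++ inference setup that would drive this path follows; the model path is a placeholder and the TensorRT engine parameters are plausible defaults, not values prescribed by this diff.)

// Hypothetical driver: loads a quantized model and enables the TensorRT
// engine at int8 precision, which selects kTRTSubgraphPasses.
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config;
  config.SetModel("./quantized_model");  // placeholder path to a QAT model
  config.EnableUseGpu(1000 /* MB */, 0 /* gpu_id */);
  // Int8 precision makes the TRT subgraph passes, including the new
  // add_support_int8_pass, run during graph optimization.
  config.EnableTensorRtEngine(1 << 30 /* workspace */, 1 /* max_batch */,
                              3 /* min_subgraph_size */,
                              paddle_infer::PrecisionType::kInt8,
                              false /* use_static */,
                              false /* use_calib_mode */);
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor != nullptr ? 0 : 1;
}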