Unverified commit c3a69111, authored by zhoutianzi666, committed by GitHub

[Paddle Inference] allow fold fill_constant && allow nms3 into trt in int8 model (#47551)

* allow fold fill_constant && allow nms3 into trt in int8 model
* use unordered_map
* fix CI failing
Parent 51507430
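Context note: with this commit, detection models that contain multiclass_nms / multiclass_nms3 can keep those ops inside the TensorRT subgraph when the engine runs in INT8 mode. Below is a minimal, hedged usage sketch of the Paddle Inference Python API in which that matters; the model paths are placeholders and the exact enable_tensorrt_engine arguments may vary across Paddle releases.

    # Hedged sketch: enabling TensorRT INT8 in Paddle Inference (Python API).
    # "model.pdmodel"/"model.pdiparams" are placeholder paths.
    import paddle.inference as paddle_infer

    config = paddle_infer.Config("model.pdmodel", "model.pdiparams")
    config.enable_use_gpu(256, 0)  # 256 MB initial GPU memory pool, device id 0
    config.enable_tensorrt_engine(
        workspace_size=1 << 30,
        max_batch_size=1,
        min_subgraph_size=3,
        precision_mode=paddle_infer.PrecisionType.Int8,
        use_static=False,
        use_calib_mode=True,  # offline calibration for INT8
    )
    predictor = paddle_infer.create_predictor(config)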
@@ -64,8 +64,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
       platform::errors::Fatal(
           "scope must not be null when applying constant floding."));
-  // Now, I don't want to fold fill_constant op in Paddle-TRT
-  std::vector<std::string> blacklist{"fill_constant", "feed"};
+  std::vector<std::string> blacklist{"feed"};
   auto op_node_sorted = framework::ir::TopologyVarientSort(
       *graph, static_cast<framework::ir::SortKind>(0));
@@ -78,7 +77,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
     bool input_persis = true;
     // map is used to record how many time a name string occures in the whole
     // graph's nodes
-    std::map<std::string, int> map;
+    std::unordered_map<std::string, int> map;
     for (auto in_node : op_node->inputs) {
       map[in_node->Name()] = 0;
       if (!in_node->Var()->Persistable()) {
...
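The hunk above removes fill_constant from the constant-folding blacklist, so a fill_constant whose inputs are all persistable can now be evaluated before the TensorRT subgraph pass runs. Purely as an illustration (the real logic lives in the C++ pass), the eligibility check amounts to something like the sketch below; the names are invented for clarity.

    # Illustrative sketch only -- not the actual C++ ConstantFoldingPass code.
    BLACKLIST = {"feed"}  # after this commit, "fill_constant" is no longer listed

    def can_fold(op_type, input_vars):
        """An op is foldable when it is not blacklisted and every input
        variable is persistable (i.e. its value is known at load time)."""
        if op_type in BLACKLIST:
            return False
        return all(var.persistable for var in input_vars)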
@@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller {
     teller_set.insert("group_norm");
     teller_set.insert("multiclass_nms3");
     teller_set.insert("multiclass_nms");
+    int8_teller_set.insert("multiclass_nms3");
+    int8_teller_set.insert("multiclass_nms");
 #endif
 #if IS_TRT_VERSION_GE(7000)
     teller_set.insert("tile");
...
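The op_teller hunk registers multiclass_nms / multiclass_nms3 in int8_teller_set in addition to teller_set, so the ops are also reported as convertible when the engine precision is INT8. Conceptually the teller is a set-membership test, as the toy sketch below illustrates; it is not Paddle's C++ SimpleOpTypeSetTeller.

    # Toy illustration of the teller idea only.
    teller_set = {"group_norm", "multiclass_nms", "multiclass_nms3", "tile"}
    int8_teller_set = {"multiclass_nms", "multiclass_nms3"}  # newly added entries

    def op_supported(op_type, with_int8):
        """Report whether an op may enter the TensorRT subgraph at this precision."""
        return op_type in (int8_teller_set if with_int8 else teller_set)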
@@ -40,7 +40,7 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
   EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
   EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
   LOG(INFO) << "num_ops: " << num_ops;
-  EXPECT_EQ(num_ops, 183);
+  EXPECT_EQ(num_ops, 181);
 }
 }  // namespace seq_pool1_tester
...
@@ -246,9 +246,10 @@ class TrtConvertExpandV2Test2(TrtLayerAutoScanTest):
         # for dynamic_shape
         generate_dynamic_shape()
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
-        yield self.create_inference_config(), (1, 3), 1e-5
+        # fill_constant will be folded by constnt folding pass!
+        yield self.create_inference_config(), (0, 3), 1e-5
         self.trt_param.precision = paddle_infer.PrecisionType.Half
-        yield self.create_inference_config(), (1, 3), 1e-3
+        yield self.create_inference_config(), (0, 3), 1e-3

     def add_skip_trt_case(self):
         pass
@@ -389,9 +390,10 @@ class TrtConvertExpandV2Test3(TrtLayerAutoScanTest):
         # for dynamic_shape
         generate_dynamic_shape()
         self.trt_param.precision = paddle_infer.PrecisionType.Float32
-        yield self.create_inference_config(), (4, 3), 1e-5
+        # fill_constant will be folded by constnt folding pass!
+        yield self.create_inference_config(), (0, 3), 1e-5
         self.trt_param.precision = paddle_infer.PrecisionType.Half
-        yield self.create_inference_config(), (4, 3), 1e-3
+        yield self.create_inference_config(), (0, 3), 1e-3

     def add_skip_trt_case(self):
         pass
...
@@ -21,7 +21,7 @@ from functools import partial
 from typing import Any, Dict, List

-class TrtConvertSplitTest(TrtLayerAutoScanTest):
+class TrtConvertFillConstantTest(TrtLayerAutoScanTest):
     def is_program_valid(self, program_config: ProgramConfig) -> bool:
         return True
@@ -36,7 +36,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
             return np.array([4]).astype(np.int32)

         for shape in [[2, 3, 4]]:
-            for num_input in [0, 1, 2, 3]:
+            for num_input in [0, 1, 2]:
                 for dtype in [5, 2, 3]:
                     for str_value in ["2", "23", "-1"]:
                         self.num_input = num_input
...