Unverified commit c3a69111, authored by zhoutianzi666, committed by GitHub

[Paddle Inference] allow fold fill_constant && allow nms3 into trt in int8 model (#47551)

* allow fold fill_constant && allow nms3 into trt in int8 model
* use unordered_map
* fix CI failing
Parent 51507430
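
The first hunk below drops fill_constant from the constant folding pass's blacklist, so such ops can be folded at analysis time. A minimal sketch of how a type blacklist typically gates folding, with illustrative names only (this is not Paddle's actual pass code):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Simplified stand-in for the blacklist check in a constant-folding pass.
// Ops whose type is blacklisted are never folded; with fill_constant
// removed, only "feed" remains blocked, so fill_constant ops whose inputs
// are all persistable become folding candidates.
bool IsFoldableType(const std::string &op_type) {
  static const std::vector<std::string> blacklist{"feed"};
  return std::find(blacklist.begin(), blacklist.end(), op_type) ==
         blacklist.end();
}

int main() {
  std::cout << std::boolalpha << IsFoldableType("fill_constant") << "\n";  // true
  std::cout << std::boolalpha << IsFoldableType("feed") << "\n";           // false
  return 0;
}
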
@@ -64,8 +64,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
platform::errors::Fatal(
"scope must not be null when applying constant floding."));
-// Now, I don't want to fold fill_constant op in Paddle-TRT
-std::vector<std::string> blacklist{"fill_constant", "feed"};
+std::vector<std::string> blacklist{"feed"};
auto op_node_sorted = framework::ir::TopologyVarientSort(
*graph, static_cast<framework::ir::SortKind>(0));
@@ -78,7 +77,7 @@ void ConstantFoldingPass::ApplyImpl(ir::Graph *graph) const {
bool input_persis = true;
// map is used to record how many times a name string occurs in the whole
// graph's nodes
-std::map<std::string, int> map;
+std::unordered_map<std::string, int> map;
for (auto in_node : op_node->inputs) {
map[in_node->Name()] = 0;
if (!in_node->Var()->Persistable()) {
......
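
The hunk above swaps the occurrence counter from std::map to std::unordered_map; key ordering is never used, so a hash map's average O(1) lookup is enough. A small self-contained sketch of the same counting pattern (the names below are illustrative, not Paddle's graph API):

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

int main() {
  // Hypothetical variable names standing in for graph node names.
  std::vector<std::string> node_names = {"w0", "x", "w0", "bias", "x"};
  // Count how many times each name occurs; no ordering is required.
  std::unordered_map<std::string, int> occurrences;
  for (const auto &name : node_names) ++occurrences[name];
  for (const auto &kv : occurrences)
    std::cout << kv.first << ": " << kv.second << "\n";
  return 0;
}
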
@@ -42,6 +42,8 @@ struct SimpleOpTypeSetTeller : public Teller {
teller_set.insert("group_norm");
teller_set.insert("multiclass_nms3");
teller_set.insert("multiclass_nms");
int8_teller_set.insert("multiclass_nms3");
int8_teller_set.insert("multiclass_nms");
#endif
#if IS_TRT_VERSION_GE(7000)
teller_set.insert("tile");
......
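
The op_teller hunk above registers multiclass_nms and multiclass_nms3 in the int8 teller set, so these ops may enter the TensorRT sub-graph when the engine runs in int8 precision. A hedged sketch of how a predictor could be configured to exercise that path, assuming the standard paddle_infer C++ API (model paths and sizes are placeholders, not from this commit):

#include "paddle_inference_api.h"  // header name may vary by install layout

int main() {
  paddle_infer::Config config;
  // Placeholder files for an int8-quantized detection model.
  config.SetModel("model.pdmodel", "model.pdiparams");
  config.EnableUseGpu(/*memory_pool_init_size_mb=*/500, /*device_id=*/0);
  // Run the TensorRT sub-graph engine in int8; ops in int8_teller_set
  // (now including multiclass_nms3) are eligible for conversion.
  config.EnableTensorRtEngine(/*workspace_size=*/1 << 30,
                              /*max_batch_size=*/1,
                              /*min_subgraph_size=*/3,
                              paddle_infer::PrecisionType::kInt8,
                              /*use_static=*/false,
                              /*use_calib_mode=*/false);
  auto predictor = paddle_infer::CreatePredictor(config);
  return 0;
}
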
@@ -40,7 +40,7 @@ TEST(Analyzer_seq_pool1_fuse_statis, fuse_statis) {
EXPECT_EQ(fuse_statis.at("squared_mat_sub_fuse"), 0);
EXPECT_EQ(fuse_statis.at("repeated_fc_relu_fuse"), 2);
LOG(INFO) << "num_ops: " << num_ops;
-EXPECT_EQ(num_ops, 183);
+EXPECT_EQ(num_ops, 181);
}
} // namespace seq_pool1_tester
......
@@ -246,9 +246,10 @@ class TrtConvertExpandV2Test2(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
-yield self.create_inference_config(), (1, 3), 1e-5
+# fill_constant will be folded by the constant folding pass!
+yield self.create_inference_config(), (0, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
-yield self.create_inference_config(), (1, 3), 1e-3
+yield self.create_inference_config(), (0, 3), 1e-3
def add_skip_trt_case(self):
pass
@@ -389,9 +390,10 @@ class TrtConvertExpandV2Test3(TrtLayerAutoScanTest):
# for dynamic_shape
generate_dynamic_shape()
self.trt_param.precision = paddle_infer.PrecisionType.Float32
-yield self.create_inference_config(), (4, 3), 1e-5
+# fill_constant will be folded by the constant folding pass!
+yield self.create_inference_config(), (0, 3), 1e-5
self.trt_param.precision = paddle_infer.PrecisionType.Half
-yield self.create_inference_config(), (4, 3), 1e-3
+yield self.create_inference_config(), (0, 3), 1e-3
def add_skip_trt_case(self):
pass
......
@@ -21,7 +21,7 @@ from functools import partial
from typing import Any, Dict, List
-class TrtConvertSplitTest(TrtLayerAutoScanTest):
+class TrtConvertFillConstantTest(TrtLayerAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True
@@ -36,7 +36,7 @@ class TrtConvertSplitTest(TrtLayerAutoScanTest):
return np.array([4]).astype(np.int32)
for shape in [[2, 3, 4]]:
-for num_input in [0, 1, 2, 3]:
+for num_input in [0, 1, 2]:
for dtype in [5, 2, 3]:
for str_value in ["2", "23", "-1"]:
self.num_input = num_input
......