提交 14737e19 编写于 作者: Z Zhaolong Xing 提交者: GitHub

[cherry-pick] [Refine Paddle-TRT INT8]: Support PaddleSlim's Resnet50 (#22485)

test=develop
上级 d2d4a02c
...@@ -34,10 +34,13 @@ using framework::ir::Node; ...@@ -34,10 +34,13 @@ using framework::ir::Node;
void analysis::TensorRtSubgraphPass::ApplyImpl( void analysis::TensorRtSubgraphPass::ApplyImpl(
framework::ir::Graph *graph) const { framework::ir::Graph *graph) const {
framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph); framework::ir::FusePassBase::Init("tensorrt_subgraph_pass", graph);
auto enable_int8 = Get<bool>("enable_int8");
auto teller = [](const framework::ir::Node *node) { auto use_calib_mode = Get<bool>("use_calib_mode");
bool no_calib_int8 = enable_int8 && !(use_calib_mode);
auto teller = [&](const framework::ir::Node *node) {
if (!node->IsOp() || !node->Op()) return false; if (!node->IsOp() || !node->Op()) return false;
return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op()); return tensorrt::OpTeller::Global().Tell(node->Op()->Type(), *node->Op(),
no_calib_int8);
}; };
framework::ir::SubGraphFuser fuser( framework::ir::SubGraphFuser fuser(
......
...@@ -98,6 +98,14 @@ class Pool2dOpConverter : public OpConverter { ...@@ -98,6 +98,14 @@ class Pool2dOpConverter : public OpConverter {
nvinfer1::ILayer *layer = nullptr; nvinfer1::ILayer *layer = nullptr;
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
float input_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
engine_->SetTensorDynamicRange(input1, input_scale);
#endif
}
if (global_pooling == true) { if (global_pooling == true) {
nv_ksize.d[0] = input_shape.d[input_dims - 2]; nv_ksize.d[0] = input_shape.d[input_dims - 2];
nv_ksize.d[1] = input_shape.d[input_dims - 1]; nv_ksize.d[1] = input_shape.d[input_dims - 1];
...@@ -159,14 +167,6 @@ class Pool2dOpConverter : public OpConverter { ...@@ -159,14 +167,6 @@ class Pool2dOpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "pool2d", {output_name}, test_mode);
if (op_desc.HasAttr("enable_int8")) {
#if IS_TRT_VERSION_GE(5000)
CHECK(op_desc.HasAttr("X_scale"));
float input_scale = boost::get<float>(op_desc.GetAttr("X_scale"));
engine_->SetTensorDynamicRange(input1, input_scale);
#endif
}
} }
}; };
......
...@@ -36,12 +36,8 @@ class SoftMaxOpConverter : public OpConverter { ...@@ -36,12 +36,8 @@ class SoftMaxOpConverter : public OpConverter {
auto output_name = op_desc.Output("Out")[0]; auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode); RreplenishLayerAndOutput(layer, "softmax", {output_name}, test_mode);
if (op_desc.HasAttr("out_scale")) { // The trt will not run int for softmax.
#if IS_TRT_VERSION_GE(5000) engine_->SetTensorDynamicRange(input1, 1.0);
float out_scale = boost::get<float>(op_desc.GetAttr("out_scale"));
engine_->SetTensorDynamicRange(layer->getOutput(0), out_scale);
#endif
}
} }
}; };
......
...@@ -26,9 +26,13 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -26,9 +26,13 @@ struct SimpleOpTypeSetTeller : public Teller {
#endif #endif
} }
bool operator()(const std::string& op_type, bool operator()(const std::string& op_type, const framework::OpDesc& desc,
const framework::OpDesc& desc) override { bool use_no_calib_int8) override {
return teller_set.count(op_type); if (use_no_calib_int8) {
return int8_teller_set.count(op_type);
} else {
return teller_set.count(op_type);
}
} }
private: private:
...@@ -59,13 +63,22 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -59,13 +63,22 @@ struct SimpleOpTypeSetTeller : public Teller {
"layer_norm", "layer_norm",
"multihead_matmul", "multihead_matmul",
}}; }};
// use this set for no calib int8.
std::unordered_set<std::string> int8_teller_set{
{"mul", "conv2d", "pool2d", "relu", "depthwise_conv2d", "softmax",
"batch_norm", "elementwise_add", "leaky_relu", "fc"}};
}; };
bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc,
bool use_no_calib_int8) {
// do not support the op which is labeled the `skip_quant` // do not support the op which is labeled the `skip_quant`
if (desc.HasAttr("op_namescope") && if ((desc.HasAttr("namescope") &&
boost::get<std::string>(desc.GetAttr("op_namescope")) == "/skip_quant_2/") boost::get<std::string>(desc.GetAttr("op_namescope")) ==
"/skip_quant_2/") ||
desc.HasAttr("skip_quant"))
return false; return false;
for (auto& teller : tellers_) { for (auto& teller : tellers_) {
if (op_type == "pool2d" || op_type == "conv2d" || if (op_type == "pool2d" || op_type == "conv2d" ||
op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") { op_type == "depthwise_conv2d" || op_type == "conv2d_transpose") {
...@@ -73,7 +86,7 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) { ...@@ -73,7 +86,7 @@ bool OpTeller::Tell(const std::string& op_type, const framework::OpDesc& desc) {
boost::get<std::vector<int>>(desc.GetAttr("paddings")); boost::get<std::vector<int>>(desc.GetAttr("paddings"));
if (paddings.size() > 2) return false; if (paddings.size() > 2) return false;
} }
if ((*teller)(op_type, desc)) return true; if ((*teller)(op_type, desc, use_no_calib_int8)) return true;
} }
return false; return false;
} }
......
...@@ -31,7 +31,8 @@ namespace tensorrt { ...@@ -31,7 +31,8 @@ namespace tensorrt {
*/ */
struct Teller { struct Teller {
virtual bool operator()(const std::string& op_type, virtual bool operator()(const std::string& op_type,
const framework::OpDesc& desc) = 0; const framework::OpDesc& desc,
bool use_no_calib_int8) = 0;
virtual ~Teller() = default; virtual ~Teller() = default;
}; };
...@@ -57,7 +58,8 @@ class OpTeller { ...@@ -57,7 +58,8 @@ class OpTeller {
return *x; return *x;
} }
bool Tell(const std::string& op_type, const framework::OpDesc& desc); bool Tell(const std::string& op_type, const framework::OpDesc& desc,
bool use_no_calib_int8 = false);
private: private:
OpTeller(); OpTeller();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册