未验证 提交 1bcf437a 编写于 作者: 周周周 提交者: GitHub

[Paddle-TRT]forbid fp64 enter into trt (#54561)

* forbid fp64 enter into trt
上级 6bbe92a1
...@@ -86,6 +86,49 @@ struct SimpleOpTypeSetTeller : public Teller { ...@@ -86,6 +86,49 @@ struct SimpleOpTypeSetTeller : public Teller {
bool use_no_calib_int8 = false, bool use_no_calib_int8 = false,
bool with_dynamic_shape = false) override { bool with_dynamic_shape = false) override {
const std::string op_type = desc.Type(); const std::string op_type = desc.Type();
std::unordered_set<std::string> control_set = {"conditional_block",
"while"};
std::unordered_set<std::string> feed_fetch_set = {"feed", "fetch"};
if (control_set.find(op_type) != control_set.end()) {
return false;
}
if (feed_fetch_set.find(op_type) != feed_fetch_set.end()) {
return false;
}
// Dont.t allow fp64!
{
auto inputs = desc.Inputs();
for (auto iter : inputs) {
for (auto var_name : iter.second) {
auto* block = desc.Block();
if (block) {
auto* var_desc = block->FindVar(var_name);
auto dtype = var_desc->GetDataType();
if (dtype == framework::proto::VarType::FP64) {
return false;
}
}
}
}
auto outputs = desc.Outputs();
for (auto iter : outputs) {
for (auto var_name : iter.second) {
auto* block = desc.Block();
if (block) {
auto* var_desc = block->FindVar(var_name);
auto dtype = var_desc->GetDataType();
if (dtype == framework::proto::VarType::FP64) {
return false;
}
}
}
}
}
// do not support the op which is labeled the `skip_quant` // do not support the op which is labeled the `skip_quant`
if ((desc.HasAttr("namescope") && if ((desc.HasAttr("namescope") &&
PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) == PADDLE_GET_CONST(std::string, desc.GetAttr("op_namescope")) ==
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册