未验证 提交 cb4eea92 编写于 作者: W Wilber 提交者: GitHub

fix convert error. (#44307)

上级 5a312fb9
......@@ -119,6 +119,15 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
return false;
}
// Whether `type` is one of the floating-point variable dtypes (FP16, FP32,
// BF16 or FP64) that the mixed-precision conversion pass is allowed to touch.
inline bool IsFloatVarType(framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP16:
    case framework::proto::VarType::FP32:
    case framework::proto::VarType::BF16:
    case framework::proto::VarType::FP64:
      return true;
    default:
      return false;
  }
}
void ConvertTensorDtype(framework::ir::Graph* graph,
const std::unordered_set<std::string>& blacklist,
bool keep_io_types,
......@@ -146,8 +155,6 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type();
auto phi_op_type = phi::TransToPhiKernelName(op_type);
// LOG(INFO) << "process op " << op_type << ", corresponding phi type is "
// << phi_op_type;
// 1. set input dtype.
if (op_type == "feed") {
block_desc = op_node->Op()->Block();
......@@ -175,12 +182,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
++num_low_precision;
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) continue;
in_var->SetDataType(to_type);
} else if (!in_var->Persistable() &&
IsFloatVarType(in_var->GetDataType()) &&
in_var->GetDataType() != to_type) {
AddCastOp(graph,
in_node,
......@@ -193,6 +202,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
}
}
for (auto* out_node : op_node->outputs) {
if (out_node->IsCtrlVar()) continue;
auto* out_var = out_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(out_node)) continue;
......@@ -202,8 +212,9 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
} else {
auto inputs = op_node->inputs;
for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (!in_var->Persistable() &&
if (!in_var->Persistable() && IsFloatVarType(in_var->GetDataType()) &&
in_var->GetDataType() != framework::proto::VarType::FP32) {
AddCastOp(graph,
in_node,
......@@ -224,6 +235,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// trt pass should explicitly add cast op if input is bf16/tf32, etc.
if (op_node->Name() == "tensorrt_engine") continue;
for (auto* in_node : op_node->inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) {
AddCastOp(graph,
......@@ -242,6 +254,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 4. if output_op's dtype is not compatible to output dtype, then just insert
// cast.
for (auto* node : output_nodes) {
if (node->IsCtrlVar()) continue;
auto var = node->Var();
if (keep_io_types && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32.
......@@ -381,7 +394,7 @@ void ConvertToMixedPrecision(const std::string& model_file,
std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) {
if (!node->IsVar()) continue;
if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() ==
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册