未验证 提交 cb4eea92 编写于 作者: W Wilber 提交者: GitHub

fix convert error. (#44307)

上级 5a312fb9
...@@ -119,6 +119,15 @@ bool WeightsShouldNotConvert(ir::Node* var_node) { ...@@ -119,6 +119,15 @@ bool WeightsShouldNotConvert(ir::Node* var_node) {
return false; return false;
} }
// Returns true when `type` is one of the floating-point variable dtypes
// (FP16 / FP32 / BF16 / FP64) that the mixed-precision pass may convert
// or insert casts around; all other dtypes (ints, bools, strings, ...)
// are left untouched by the conversion logic.
inline bool IsFloatVarType(framework::proto::VarType::Type type) {
  switch (type) {
    case framework::proto::VarType::FP16:
    case framework::proto::VarType::FP32:
    case framework::proto::VarType::BF16:
    case framework::proto::VarType::FP64:
      return true;
    default:
      return false;
  }
}
void ConvertTensorDtype(framework::ir::Graph* graph, void ConvertTensorDtype(framework::ir::Graph* graph,
const std::unordered_set<std::string>& blacklist, const std::unordered_set<std::string>& blacklist,
bool keep_io_types, bool keep_io_types,
...@@ -146,8 +155,6 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -146,8 +155,6 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
if (!op_node->IsOp()) continue; if (!op_node->IsOp()) continue;
auto op_type = op_node->Op()->Type(); auto op_type = op_node->Op()->Type();
auto phi_op_type = phi::TransToPhiKernelName(op_type); auto phi_op_type = phi::TransToPhiKernelName(op_type);
// LOG(INFO) << "process op " << op_type << ", corresponding phi type is "
// << phi_op_type;
// 1. set input dtype. // 1. set input dtype.
if (op_type == "feed") { if (op_type == "feed") {
block_desc = op_node->Op()->Block(); block_desc = op_node->Op()->Block();
...@@ -175,12 +182,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -175,12 +182,14 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
++num_low_precision; ++num_low_precision;
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var(); auto* in_var = in_node->Var();
if (in_var->Persistable() && if (in_var->Persistable() &&
in_var->GetDataType() == framework::proto::VarType::FP32) { in_var->GetDataType() == framework::proto::VarType::FP32) {
if (WeightsShouldNotConvert(in_node)) continue; if (WeightsShouldNotConvert(in_node)) continue;
in_var->SetDataType(to_type); in_var->SetDataType(to_type);
} else if (!in_var->Persistable() && } else if (!in_var->Persistable() &&
IsFloatVarType(in_var->GetDataType()) &&
in_var->GetDataType() != to_type) { in_var->GetDataType() != to_type) {
AddCastOp(graph, AddCastOp(graph,
in_node, in_node,
...@@ -193,6 +202,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -193,6 +202,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
} }
} }
for (auto* out_node : op_node->outputs) { for (auto* out_node : op_node->outputs) {
if (out_node->IsCtrlVar()) continue;
auto* out_var = out_node->Var(); auto* out_var = out_node->Var();
if (out_var->GetDataType() == framework::proto::VarType::FP32) { if (out_var->GetDataType() == framework::proto::VarType::FP32) {
if (OutShouldNotConvert(out_node)) continue; if (OutShouldNotConvert(out_node)) continue;
...@@ -202,8 +212,9 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -202,8 +212,9 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
} else { } else {
auto inputs = op_node->inputs; auto inputs = op_node->inputs;
for (auto* in_node : inputs) { for (auto* in_node : inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var(); auto* in_var = in_node->Var();
if (!in_var->Persistable() && if (!in_var->Persistable() && IsFloatVarType(in_var->GetDataType()) &&
in_var->GetDataType() != framework::proto::VarType::FP32) { in_var->GetDataType() != framework::proto::VarType::FP32) {
AddCastOp(graph, AddCastOp(graph,
in_node, in_node,
...@@ -224,6 +235,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -224,6 +235,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// trt pass should explicitly add cast op if input is bf16/tf32, etc. // trt pass should explicitly add cast op if input is bf16/tf32, etc.
if (op_node->Name() == "tensorrt_engine") continue; if (op_node->Name() == "tensorrt_engine") continue;
for (auto* in_node : op_node->inputs) { for (auto* in_node : op_node->inputs) {
if (in_node->IsCtrlVar()) continue;
auto* in_var = in_node->Var(); auto* in_var = in_node->Var();
if (in_var->GetDataType() == to_type) { if (in_var->GetDataType() == to_type) {
AddCastOp(graph, AddCastOp(graph,
...@@ -242,6 +254,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph, ...@@ -242,6 +254,7 @@ void ConvertTensorDtype(framework::ir::Graph* graph,
// 4. if output_op's dtype is not compatible to output dtype, then just insert // 4. if output_op's dtype is not compatible to output dtype, then just insert
// cast. // cast.
for (auto* node : output_nodes) { for (auto* node : output_nodes) {
if (node->IsCtrlVar()) continue;
auto var = node->Var(); auto var = node->Var();
if (keep_io_types && var->GetDataType() == to_type) { if (keep_io_types && var->GetDataType() == to_type) {
// fp16/bf16 -> fp32. // fp16/bf16 -> fp32.
...@@ -381,7 +394,7 @@ void ConvertToMixedPrecision(const std::string& model_file, ...@@ -381,7 +394,7 @@ void ConvertToMixedPrecision(const std::string& model_file,
std::unordered_set<std::string> weights_should_be_fp32; std::unordered_set<std::string> weights_should_be_fp32;
for (auto* node : graph->Nodes()) { for (auto* node : graph->Nodes()) {
if (!node->IsVar()) continue; if (!(node->IsVar() && !node->IsCtrlVar())) continue;
if (node->Var()->GetType() == if (node->Var()->GetType() ==
paddle::framework::proto::VarType::SELECTED_ROWS || paddle::framework::proto::VarType::SELECTED_ROWS ||
node->Var()->GetType() == node->Var()->GetType() ==
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册