From 1006383b8485c3409a9a7e09d9623df7e03f7364 Mon Sep 17 00:00:00 2001
From: Allen Guo
Date: Mon, 17 Jan 2022 12:41:03 +0800
Subject: [PATCH] add TransferCastOpPass, DeleteScaleOpPass (#38985)

Co-authored-by: Xiaobing Wang
Co-authored-by: Allen Guo
Co-authored-by: Zhixin Yao
Co-authored-by: Haicheng Jiang
Co-authored-by: Han Zhao
---
 .../framework/ir/ipu/delete_scale_op_pass.cc  | 121 ++++++++++++++++++
 ...ipu_pass_base.h => delete_scale_op_pass.h} |  15 +--
 .../framework/ir/ipu/transfer_cast_op_pass.cc |  53 ++++++++
 ...u_pass_base.cc => transfer_cast_op_pass.h} |  14 +-
 4 files changed, 186 insertions(+), 17 deletions(-)
 create mode 100644 paddle/fluid/framework/ir/ipu/delete_scale_op_pass.cc
 rename paddle/fluid/framework/ir/ipu/{ipu_pass_base.h => delete_scale_op_pass.h} (69%)
 create mode 100644 paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.cc
 rename paddle/fluid/framework/ir/ipu/{ipu_pass_base.cc => transfer_cast_op_pass.h} (75%)

diff --git a/paddle/fluid/framework/ir/ipu/delete_scale_op_pass.cc b/paddle/fluid/framework/ir/ipu/delete_scale_op_pass.cc
new file mode 100644
index 00000000000..933718bc1c1
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/delete_scale_op_pass.cc
@@ -0,0 +1,121 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/delete_scale_op_pass.h"
+
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/variable_helper.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// This pass deletes identity scale ops, i.e. those with scale == 1 and
+// bias == 0. A scale op is kept if it is the only operator in the graph.
+void DeleteScaleOpPass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(10) << "enter DeleteScaleOpPass::ApplyImpl";
+  VLOG(10) << "Raw Graph: ";
+  VLOG(10) << DebugString(graph);
+
+  auto nodes = ir::TopologySortOperations(*graph);
+
+  // delete identity scale ops
+  for (auto node : nodes) {
+    if (!node->Op()) {
+      continue;
+    }
+    auto op = node->Op();
+    if (op->Type() == "scale") {
+      auto input_var_node = node->inputs[0];
+      auto output_var_node = node->outputs[0];
+      // only optimize scale ops that compute x * 1 + 0
+      auto scale = BOOST_GET_CONST(float, op->GetAttr("scale"));
+      auto bias = BOOST_GET_CONST(float, op->GetAttr("bias"));
+      if (scale != 1 || bias != 0) {
+        continue;
+      }
+      // do not optimize if the scale op is the only operator in the graph
+      if (input_var_node->inputs.size() == 0 &&
+          output_var_node->outputs.size() == 0) {
+        continue;
+      }
+      VLOG(10) << "scale op is to be deleted, scale: " << scale
+               << " bias: " << bias;
+      // relink the neighbouring nodes around the scale op
+      ir::Node* next_op_node = nullptr;
+      ir::Node* pre_op_node = nullptr;
+      // the scale op has a consumer: feed it the scale op's input var
+      if (node->outputs[0]->outputs.size() > 0) {
+        next_op_node = node->outputs[0]->outputs[0];
+        input_var_node->outputs.push_back(next_op_node);
+        next_op_node->inputs.push_back(input_var_node);
+        platform::ipu::DisConnectNodes(output_var_node, node);
+        platform::ipu::DisConnectNodes(input_var_node, node);
+        auto var_map = next_op_node->Op()->Inputs();
+        for (auto& name_m : var_map) {
+          if (std::find(name_m.second.begin(), name_m.second.end(),
+                        output_var_node->Name()) != name_m.second.end()) {
+            std::vector<std::string> new_inputs;
+            for (auto& i_n : name_m.second) {
+              if (i_n != output_var_node->Name()) {
+                new_inputs.push_back(i_n);
+              }
+            }
+            new_inputs.push_back(input_var_node->Name());
+            next_op_node->Op()->SetInput(name_m.first, new_inputs);
+            next_op_node->Op()->Flush();
+          }
+        }
+        GraphSafeRemoveNodes(graph, {node, output_var_node});
+      } else {  // the scale op is the last op: relink through its producer
+        pre_op_node = node->inputs[0]->inputs[0];
+        output_var_node->inputs.push_back(pre_op_node);
+        pre_op_node->outputs.push_back(output_var_node);
+        platform::ipu::DisConnectNodes(input_var_node, node);
+        platform::ipu::DisConnectNodes(output_var_node, node);
+
+        auto var_map = pre_op_node->Op()->Outputs();
+        for (auto& name_m : var_map) {
+          if (std::find(name_m.second.begin(), name_m.second.end(),
+                        input_var_node->Name()) != name_m.second.end()) {
+            std::vector<std::string> new_outputs;
+            for (auto& i_n : name_m.second) {
+              if (i_n != input_var_node->Name()) {
+                new_outputs.push_back(i_n);
+              }
+            }
+            new_outputs.push_back(output_var_node->Name());
+            pre_op_node->Op()->SetOutput(name_m.first, new_outputs);
+            pre_op_node->Op()->Flush();
+          }
+        }
+        GraphSafeRemoveNodes(graph, {node, input_var_node});
+      }
+    }
+  }
+  VLOG(10) << "Post Graph: ";
+  VLOG(10) << DebugString(graph);
+  VLOG(10) << "leave DeleteScaleOpPass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(delete_scale_op_pass, paddle::framework::ir::DeleteScaleOpPass);
diff --git a/paddle/fluid/framework/ir/ipu/ipu_pass_base.h b/paddle/fluid/framework/ir/ipu/delete_scale_op_pass.h
similarity index 69%
rename from paddle/fluid/framework/ir/ipu/ipu_pass_base.h
rename to paddle/fluid/framework/ir/ipu/delete_scale_op_pass.h
index b56d3e4c65b..e2e56c9d5ef 100644
--- a/paddle/fluid/framework/ir/ipu/ipu_pass_base.h
+++ b/paddle/fluid/framework/ir/ipu/delete_scale_op_pass.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -14,22 +14,15 @@

 #pragma once

-#include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/pass.h"
-#include "paddle/fluid/framework/scope.h"
-
+#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
 namespace paddle {
 namespace framework {
 namespace ir {

-class IPUPassBase : public Pass {
- public:
-  void Init(const std::string& repr, Graph* graph) const;
-  virtual ~IPUPassBase() {}
-
+class DeleteScaleOpPass : public Pass {
  protected:
-  mutable Graph* graph_;
-  mutable std::string repr_;
+  void ApplyImpl(ir::Graph* graph) const override;
 };

 }  // namespace ir
diff --git a/paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.cc b/paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.cc
new file mode 100644
index 00000000000..e754ba72ad8
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.cc
@@ -0,0 +1,53 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.h"
+
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+// Transfer the target dtype of cast ops from FP32 to FP16
+// when FP16 mode is enabled.
+void TransferCastOpPass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(10) << "enter TransferCastOpPass::ApplyImpl";
+  VLOG(10) << "Raw Graph: ";
+  VLOG(10) << DebugString(graph);
+
+  auto ipu_backend = platform::ipu::IpuBackend::GetInstance();
+  auto enable_fp16 = ipu_backend->GetIpuStrategy()->enable_fp16;
+  if (enable_fp16) {
+    for (auto* node : graph->Nodes()) {
+      if (node->IsOp() && node->Op()->Type() == "popart_cast") {
+        if (BOOST_GET_CONST(std::string, node->Op()->GetAttr("to")) ==
+            "FLOAT") {
+          node->Op()->SetAttr("to", std::string("FLOAT16"));
+        }
+      }
+    }
+  }
+
+  VLOG(10) << "Post Graph: ";
+  VLOG(10) << DebugString(graph);
+  VLOG(10) << "leave TransferCastOpPass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(transfer_cast_op_pass, paddle::framework::ir::TransferCastOpPass);
diff --git a/paddle/fluid/framework/ir/ipu/ipu_pass_base.cc b/paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.h
similarity index 75%
rename from paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
rename to paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.h
index ba9233eeb8c..580fec10f2a 100644
--- a/paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
+++ b/paddle/fluid/framework/ir/ipu/transfer_cast_op_pass.h
@@ -1,4 +1,4 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -12,16 +12,18 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+#pragma once
+
+#include "paddle/fluid/framework/ir/pass.h"

 namespace paddle {
 namespace framework {
 namespace ir {

-void IPUPassBase::Init(const std::string& repr, Graph* graph) const {
-  repr_ = repr;
-  graph_ = graph;
-}
+class TransferCastOpPass : public Pass {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};

 }  // namespace ir
 }  // namespace framework
-- 
GitLab
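
A minimal sketch of how a pass registered this way is typically fetched and
applied through Paddle's PassRegistry. The helper below is illustrative only
and not part of the patch; it assumes the ir::Graph has already been built
and that the translation unit registering delete_scale_op_pass is linked in:

    #include <memory>

    #include "paddle/fluid/framework/ir/graph.h"
    #include "paddle/fluid/framework/ir/pass.h"

    // Pull in the pass registration from delete_scale_op_pass.cc.
    USE_PASS(delete_scale_op_pass);

    void RunDeleteScaleOpPass(
        std::unique_ptr<paddle::framework::ir::Graph>* graph) {
      // Look up the pass under the name given to REGISTER_PASS.
      auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
          "delete_scale_op_pass");
      // Apply transforms the graph and returns it; reseat the owning pointer.
      graph->reset(pass->Apply(graph->release()));
    }

The same pattern applies to transfer_cast_op_pass, which additionally expects
an active IpuBackend instance whose IpuStrategy has enable_fp16 set.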