diff --git a/paddle/fluid/framework/ir/ipu/avg_shard_pass.cc b/paddle/fluid/framework/ir/ipu/avg_shard_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..9dcbbb9c9856e5e40d4d79bebc3cdac42cf407c8
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/avg_shard_pass.cc
@@ -0,0 +1,56 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/avg_shard_pass.h"
+
+#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
+
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void AvgShardPass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(10) << "enter AvgShardPass::ApplyImpl";
+
+  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
+      platform::ipu::IpuBackend::GetInstance();
+
+  if (ipu_backend->GetIpuStrategy()->need_avg_shard) {
+    VLOG(10) << "start AvgShardPass";
+    auto nodes = ir::TopologySortOperations(*graph);
+    auto num_ipus = ipu_backend->GetIpuStrategy()->num_ipus;
+
+    int shard_position = nodes.size() / num_ipus;
+    int index_and_stage = -1;
+    for (int i = 0; i < nodes.size(); i++) {
+      if ((i % shard_position) == 0 && index_and_stage < num_ipus - 1) {
+        index_and_stage++;
+      }
+      nodes[i]->Op()->SetAttr("ipu_index", index_and_stage);
+      nodes[i]->Op()->SetAttr("ipu_stage", index_and_stage);
+    }
+    VLOG(10) << "end AvgShardPass";
+  }
+
+  VLOG(10) << "leave AvgShardPass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(avg_shard_pass, paddle::framework::ir::AvgShardPass);
diff --git a/paddle/fluid/framework/ir/ipu/avg_shard_pass.h b/paddle/fluid/framework/ir/ipu/avg_shard_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..b13acbd198dd524d7e8c14eb03ca1dbc367b257a
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/avg_shard_pass.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class AvgShardPass : public IPUPassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc b/paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5dcfddf6187f2b6e79d9f478489110a484cb0575
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc
@@ -0,0 +1,133 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h"
+
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(10) << "enter ForwardGraphExtractPass::ApplyImpl";
+
+  std::unordered_map<OpRole, std::unordered_set<ir::Node*>> all_ops{
+      {OpRole::kForward, {}}, {OpRole::kBackward, {}},
+      {OpRole::kOptimize, {}}, {OpRole::kRPC, {}},
+      {OpRole::kDist, {}}, {OpRole::kLRSched, {}},
+      {OpRole::kLoss, {}}, {OpRole::kNotSpecified, {}}};
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsOp()) {
+      continue;
+    }
+    auto op_role = BOOST_GET_MUTABLE(int, node->Op()->GetAttr("op_role"));
+    if (op_role == static_cast<int>(OpRole::kForward)) {
+      all_ops[OpRole::kForward].insert(node);
+    } else if (op_role == static_cast<int>(OpRole::kBackward)) {
+      all_ops[OpRole::kBackward].insert(node);
+    } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
+      all_ops[OpRole::kOptimize].insert(node);
+    } else if (op_role == static_cast<int>(OpRole::kRPC)) {
+    } else if (op_role == static_cast<int>(OpRole::kDist)) {
+    } else if (op_role == static_cast<int>(OpRole::kLRSched)) {
+    } else if (op_role == static_cast<int>(OpRole::kLoss)) {
+      all_ops[OpRole::kLoss].insert(node);
+    } else if (op_role == static_cast<int>(OpRole::kNotSpecified)) {
+      LOG(WARNING) << "Op: " << node->Name() << " OpRole is NotSpecified ";
+    }
+  }
+
+  std::unordered_set<ir::Node*> forward_vars;
+  std::unordered_set<ir::Node*> backward_vars;
+  std::unordered_set<ir::Node*> control_vars;
+  // forward_vars
+  for (auto& nodes : std::array<std::unordered_set<ir::Node*>, 2>{
+           all_ops[OpRole::kForward], all_ops[OpRole::kLoss]}) {
+    for (auto* node : nodes) {
+      for (auto* in_node : node->inputs) {
+        forward_vars.insert(in_node);
+      }
+      for (auto* out_node : node->outputs) {
+        forward_vars.insert(out_node);
+      }
+    }
+  }
+  // control_vars & backward_vars
+  for (auto* node : graph->Nodes()) {
+    if (!node->IsVar()) {
+      continue;
+    }
+    if (node->IsCtrlVar()) {
+      control_vars.insert(node);
+    }
+    for (auto* in_node : node->inputs) {
+      if (all_ops[OpRole::kOptimize].count(in_node)) {
+        backward_vars.insert(node);
+      }
+    }
+  }
+  // all removed node
+  std::unordered_set<ir::Node*> rm_nodes;
+  for (auto* node : graph->Nodes()) {
+    if (backward_vars.count(node)) {
+      rm_nodes.insert(node);
+    } else if (control_vars.count(node)) {
+      rm_nodes.insert(node);
+    } else if (all_ops[OpRole::kBackward].count(node)) {
+      rm_nodes.insert(node);
+    } else if (all_ops[OpRole::kForward].count(node) == 0 &&
+               all_ops[OpRole::kLoss].count(node) == 0 &&
+               forward_vars.count(node) == 0) {
+      rm_nodes.insert(node);
+    } else if (node->Name() == "feed" || node->Name() == "fetch") {
+      rm_nodes.insert(node);
+    }
+  }
+
+  VLOG(10) << "Remove Node: ";
+  for (auto* node : rm_nodes) {
+    // remove node relations
+    for (auto* node_in : node->inputs) {
+      for (size_t i = 0; i < node_in->outputs.size(); ++i) {
+        if (node_in->outputs[i] == node) {
+          node_in->outputs.erase(node_in->outputs.begin() + i);
+          break;
+        }
+      }
+    }
+    for (auto* node_out : node->outputs) {
+      for (size_t i = 0; i < node_out->inputs.size(); ++i) {
+        if (node_out->inputs[i] == node) {
+          node_out->inputs.erase(node_out->inputs.begin() + i);
+          break;
+        }
+      }
+    }
+    VLOG(10) << "\t" << node->Name();
+    graph->RemoveNode(node);
+  }
+
+  VLOG(10) << "Post Graph: ";
+  VLOG(10) << DebugString(graph);
+
+  VLOG(10) << "leave ForwardGraphExtractPass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(forward_graph_extract_pass,
+              paddle::framework::ir::ForwardGraphExtractPass);
diff --git a/paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h b/paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..afa9f1c15f2ab8c3c27dda9aa3485320725f0ea4
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h
@@ -0,0 +1,31 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class ForwardGraphExtractPass : public IPUPassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc b/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ceef27ac1ce3c0a8ecd15f86a2dbae098059e0a8
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/infer_shape_pass.cc
@@ -0,0 +1,108 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/infer_shape_pass.h"
+
+#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
+
+#include "paddle/fluid/framework/ddim.h"
+#include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/variable_helper.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void InferShapePass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(10) << "enter InferShapePass::ApplyImpl";
+  VLOG(10) << "Raw Graph: ";
+  VLOG(10) << DebugString(graph);
+
+  std::shared_ptr<platform::ipu::IpuBackend> ipu_backend =
+      platform::ipu::IpuBackend::GetInstance();
+  auto batch_size = ipu_backend->GetIpuStrategy()->batch_size;
+
+  auto feed_list = Get<std::vector<std::string>>("feed_list");
+  for (auto node : graph->Nodes()) {
+    if (!node->IsVar()) {
+      continue;
+    }
+    bool is_feed = std::find(feed_list.begin(), feed_list.end(),
+                             node->Name()) != feed_list.end();
+    if (is_feed) {
+      auto input_shape = node->Var()->GetShape();
+      if (input_shape[0] <= -1) {
+        input_shape[0] = batch_size;
+        node->Var()->SetShape(input_shape);
+      }
+      // int64->int32
+      if (node->Var()->GetDataType() == proto::VarType::INT64) {
+        node->Var()->SetDataType(proto::VarType::INT32);
+      }
+    }
+  }
+
+  // temp scope for shape inference
+  std::shared_ptr<paddle::framework::Scope> scope(
+      new paddle::framework::Scope());
+  for (auto node : graph->Nodes()) {
+    if (!node->IsVar()) {
+      continue;
+    }
+    auto var_desc = node->Var();
+    auto* ptr = scope->Var(var_desc->Name());
+    paddle::framework::InitializeVariable(ptr, var_desc->GetType());
+
+    auto tensor = ptr->GetMutable<paddle::framework::LoDTensor>();
+    tensor->Resize(paddle::framework::make_ddim(var_desc->GetShape()));
+  }
+
+  // infer shape
+  auto nodes = ir::TopologySortOperations(*graph);
+  for (auto node : nodes) {
+    auto op_desc = node->Op();
+    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
+    paddle::framework::RuntimeContext ctx(op->Inputs(), op->Outputs(), *scope);
+    op->RuntimeInferShape(*scope, paddle::platform::CPUPlace(), ctx);
+
+    for (auto it = ctx.outputs.begin(); it != ctx.outputs.end(); it++) {
+      for (int i = 0; i < it->second.size(); i++) {
+        auto output_name = op_desc->Output(it->first)[i];
+        auto dim =
+            it->second[i]->GetMutable<paddle::framework::LoDTensor>()->dims();
+        auto new_shape = paddle::framework::vectorize(dim);
+        for (auto output_node : node->outputs) {
+          if (output_node->Name() == output_name) {
+            output_node->Var()->SetShape(new_shape);
+          }
+        }
+      }
+    }
+  }
+  // release the temp scope
+  scope.reset();
+
+  VLOG(10) << "Post Graph: ";
+  VLOG(10) << DebugString(graph);
+  VLOG(10) << "leave InferShapePass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(infer_shape_pass, paddle::framework::ir::InferShapePass)
+    .RequirePassAttr("feed_list");
diff --git a/paddle/fluid/framework/ir/ipu/infer_shape_pass.h b/paddle/fluid/framework/ir/ipu/infer_shape_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..3e8148b7f066d9e6f3d4f2bd34b1d561bf4293a8
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/infer_shape_pass.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class InferShapePass : public IPUPassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/inference_postprocess_pass.cc b/paddle/fluid/framework/ir/ipu/inference_postprocess_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..616139a52ac06c35b252bd723730ab4ac96f4dfc
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/inference_postprocess_pass.cc
@@ -0,0 +1,89 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h"
+
+#include "paddle/fluid/framework/ir/fuse_pass_base.h"
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
+#include "paddle/fluid/platform/device/ipu/ipu_strategy.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void InferencePostprocessPass::ApplyImpl(ir::Graph *graph) const {
+  VLOG(10) << "enter InferencePostprocessPass::ApplyImpl";
+
+  std::vector<std::string> feed_list;
+  feed_list = Get<std::vector<std::string>>("feed_list");
+  std::vector<std::string> fetch_list;
+  fetch_list = Get<std::vector<std::string>>("fetch_list");
+
+  auto *feed_var = new paddle::framework::VarDesc("feed");
+  feed_var->SetType(proto::VarType::FEED_MINIBATCH);
+  auto *feed_var_node = graph->CreateVarNode(feed_var);
+
+  auto *fetch_var = new paddle::framework::VarDesc("fetch");
+  fetch_var->SetType(proto::VarType::FETCH_LIST);
+  auto *fetch_var_node = graph->CreateVarNode(fetch_var);
+
+  for (int i = 0; i < feed_list.size(); i++) {
+    for (auto node : graph->Nodes()) {
+      if (node->Name() == feed_list[i]) {
+        auto *op = new paddle::framework::OpDesc();
+        op->SetType("feed");
+        op->SetInput("X", {"feed"});
+        op->SetOutput("Out", {node->Name()});
+        op->SetAttr("col", i);
+        auto *op_node = graph->CreateOpNode(op);
+        node->inputs.push_back(op_node);
+        op_node->outputs.push_back(node);
+        feed_var_node->outputs.push_back(op_node);
+        op_node->inputs.push_back(feed_var_node);
+        break;
+      }
+    }
+  }
+
+  for (int i = 0; i < fetch_list.size(); i++) {
+    for (auto node : graph->Nodes()) {
+      if (node->Name() == fetch_list[i]) {
+        auto *op = new paddle::framework::OpDesc();
+        op->SetType("fetch");
+        op->SetInput("X", {node->Name()});
+        op->SetOutput("Out", {"fetch"});
+        op->SetAttr("col", i);
+        auto *op_node = graph->CreateOpNode(op);
+        node->outputs.push_back(op_node);
+        op_node->inputs.push_back(node);
+        fetch_var_node->inputs.push_back(op_node);
+        op_node->outputs.push_back(fetch_var_node);
+        break;
+      }
+    }
+  }
+
+  VLOG(10) << "leave InferencePostprocessPass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(inference_postprocess_pass,
+              paddle::framework::ir::InferencePostprocessPass)
+    .RequirePassAttr("feed_list")
+    .RequirePassAttr("fetch_list");
diff --git a/paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h b/paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..e80e1905d4ad79c1df1caa4f9e8bd463dfa6654e
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/inference_postprocess_pass.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class InferencePostprocessPass : public IPUPassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/inference_process_pass.cc b/paddle/fluid/framework/ir/ipu/inference_process_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d02dcce0cc62c8d96cab1ace860c7fbe913a6e10
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/inference_process_pass.cc
@@ -0,0 +1,129 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/ipu/inference_process_pass.h" + +#include "paddle/fluid/platform/device/ipu/ipu_backend.h" +#include "paddle/fluid/platform/device/ipu/ipu_strategy.h" + +#include "paddle/fluid/framework/ir/fuse_pass_base.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace framework { +namespace ir { + +void InferenceProcessPass::ApplyImpl(ir::Graph* graph) const { + VLOG(10) << "enter InferenceProcessPass::ApplyImpl"; + + // Get a new instance of ipu_backend + std::shared_ptr ipu_backend = + platform::ipu::IpuBackend::GetNewInstance(); + + // Set scope + auto& scope = graph->Get(kParamScopeAttr); + ipu_backend->SetScope(scope); + + // Set ipu_strategy + static std::shared_ptr ipu_strategy_instance_( + new platform::ipu::IpuStrategy()); + ipu_strategy_instance_->is_training = false; + auto num_ipus = graph->Get("num_ipus"); + ipu_strategy_instance_->num_ipus = num_ipus; + if (num_ipus > 1) { + ipu_strategy_instance_->popart_options_.virtualGraphMode = + platform::ipu::VirtualGraphMode::Manual; + } else { + ipu_strategy_instance_->popart_options_.virtualGraphMode = + platform::ipu::VirtualGraphMode::Off; + } + + auto enable_pipelining = graph->Get("enable_pipelining"); + ipu_strategy_instance_->popart_options_.enablePipelining = enable_pipelining; + if (enable_pipelining) { + auto batches_per_step = graph->Get("batches_per_step"); + PADDLE_ENFORCE_GE( + batches_per_step, num_ipus, + platform::errors::InvalidArgument("Batched per step should be equal or " + "greater than the number of IPUs")); + ipu_strategy_instance_->batches_per_step = batches_per_step; + } + ipu_strategy_instance_->batch_size = graph->Get("batch_size"); + ipu_strategy_instance_->need_avg_shard = graph->Get("need_avg_shard"); + + ipu_backend->SetIpuStrategy(*(ipu_strategy_instance_.get())); + + // Get feed_list and fetch list + std::vector feed_list = {}; + std::vector fetch_list = {}; + for (auto node : graph->Nodes()) { + if (node->Name() == "feed") { + if (node->IsOp()) { + feed_list.push_back(""); + } + } else if (node->Name() == "fetch") { + if (node->IsOp()) { + fetch_list.push_back(""); + } + } + } + for (auto node : graph->Nodes()) { + if (node->Name() == "feed") { + if (node->IsOp()) { + feed_list[BOOST_GET_CONST(int, node->Op()->GetAttr("col"))] = + node->outputs[0]->Name(); + } + } else if (node->Name() == "fetch") { + if (node->IsOp()) { + fetch_list[BOOST_GET_CONST(int, node->Op()->GetAttr("col"))] = + node->inputs[0]->Name(); + } + } + } + + // Run passes + std::vector graph_pass = {"forward_graph_extract_pass", + "infer_shape_pass", "avg_shard_pass", + "popart_canonicalization_pass"}; + std::vector compile_pass = { + "ipu_inplace_pass", "ipu_graph_builder_pass", "ipu_runtime_replacer_pass", + "inference_postprocess_pass"}; + for (auto pass_name : graph_pass) { + auto pass = PassRegistry::Instance().Get(pass_name); + if (pass_name == "infer_shape_pass") { + pass->Set("feed_list", new std::vector(feed_list.begin(), + feed_list.end())); + } + pass->Apply(graph); + } + + for (auto pass_name : compile_pass) { + auto pass = PassRegistry::Instance().Get(pass_name); + pass->Set("feed_list", + new std::vector(feed_list.begin(), feed_list.end())); + pass->Set("fetch_list", new std::vector(fetch_list.begin(), + fetch_list.end())); + pass->Apply(graph); + } + + VLOG(10) << "leave InferenceProcessPass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + 
+REGISTER_PASS(inference_process_pass,
+              paddle::framework::ir::InferenceProcessPass);
diff --git a/paddle/fluid/framework/ir/ipu/inference_process_pass.h b/paddle/fluid/framework/ir/ipu/inference_process_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..bac0e88377f7c6d8ef124817256c5059345f714e
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/inference_process_pass.h
@@ -0,0 +1,30 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class InferenceProcessPass : public IPUPassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc b/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..5a53466089bc88fe25b3aaed54736f8e78994513
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.cc
@@ -0,0 +1,52 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h" + +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/device/ipu/ipu_backend.h" + +namespace paddle { +namespace framework { +namespace ir { + +void IpuGraphBuilderPass::ApplyImpl(ir::Graph* graph) const { + VLOG(10) << "enter IpuGraphBuilderPass::ApplyImpl"; + VLOG(10) << "Raw Graph: "; + VLOG(10) << DebugString(graph); + + std::vector feed_list; + feed_list = Get>("feed_list"); + + std::vector fetch_list; + fetch_list = Get>("fetch_list"); + + std::shared_ptr ipu_backend = + platform::ipu::IpuBackend::GetInstance(); + + ipu_backend->Compile(graph, feed_list, fetch_list); + + VLOG(10) << "Post Graph: "; + VLOG(10) << DebugString(graph); + VLOG(10) << "leave IpuGraphBuilderPass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(ipu_graph_builder_pass, + paddle::framework::ir::IpuGraphBuilderPass) + .RequirePassAttr("feed_list") + .RequirePassAttr("fetch_list"); diff --git a/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h b/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6237df36480335e3971aaffe120ca2374966fcb3 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/ipu_graph_builder_pass.h @@ -0,0 +1,31 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class IpuGraphBuilderPass : public IPUPassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc b/paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..d3f1f1633ffc947385c2924a9213de289cb3e3d8 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/ipu_inplace_pass.cc @@ -0,0 +1,85 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h" + +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +std::string GenerateVarName(Node *node) { + return node->Name() + "_" + std::to_string(node->id()); +} + +void IpuInplacePass::ApplyImpl(ir::Graph *graph) const { + // use this pass after forward_graph_extract_pass + // raise error if the inplaced var both in feed_list & fetch_list + VLOG(10) << "enter IpuInplacePass::ApplyImpl"; + VLOG(10) << "Raw Graph: "; + VLOG(10) << DebugString(graph); + + std::vector feed_list; + feed_list = Get>("feed_list"); + std::vector fetch_list; + fetch_list = Get>("fetch_list"); + + std::map var_name; + for (auto *node : graph->Nodes()) { + if (node->IsVar()) { + if (var_name.find(node->Name()) == var_name.end()) { + var_name.emplace(node->Name(), 1); + } else { + var_name[node->Name()]++; + } + } + } + + for (auto *node : graph->Nodes()) { + if (node->IsVar()) { + if (var_name[node->Name()] > 1) { + auto is_feed = (std::find(feed_list.begin(), feed_list.end(), + node->Name()) != feed_list.end()) && + (node->inputs.size() == 0); + auto is_fetch = (std::find(fetch_list.begin(), fetch_list.end(), + node->Name()) != fetch_list.end()) && + (node->outputs.size() == 0); + if (!is_feed && !is_fetch && !node->Var()->Persistable()) { + auto old_name = node->Name(); + auto new_name = GenerateVarName(node); + node->RenameVar(new_name); + for (auto *op_in : node->inputs) { + op_in->Op()->RenameOutput(old_name, new_name); + } + for (auto *op_out : node->outputs) { + op_out->Op()->RenameInput(old_name, new_name); + } + } + } + } + } + + VLOG(10) << "Post Graph: "; + VLOG(10) << DebugString(graph); + VLOG(10) << "leave IpuInplacePass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(ipu_inplace_pass, paddle::framework::ir::IpuInplacePass) + .RequirePassAttr("feed_list") + .RequirePassAttr("fetch_list"); diff --git a/paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h b/paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..86756276c8c3dcf6c010ad05b9a39bfa3bda71f5 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/ipu_inplace_pass.h @@ -0,0 +1,30 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class IpuInplacePass : public IPUPassBase {
+ protected:
+  void ApplyImpl(ir::Graph* graph) const override;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/ipu_pass_base.cc b/paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
new file mode 100644
index 0000000000000000000000000000000000000000..ba9233eeb8cb9512740b6ebdf90d07d8fee2d42a
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/ipu_pass_base.cc
@@ -0,0 +1,28 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+void IPUPassBase::Init(const std::string& repr, Graph* graph) const {
+  repr_ = repr;
+  graph_ = graph;
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
diff --git a/paddle/fluid/framework/ir/ipu/ipu_pass_base.h b/paddle/fluid/framework/ir/ipu/ipu_pass_base.h
new file mode 100644
index 0000000000000000000000000000000000000000..b56d3e4c65b1c0cf466824ae97c351c8f68c1380
--- /dev/null
+++ b/paddle/fluid/framework/ir/ipu/ipu_pass_base.h
@@ -0,0 +1,37 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/pass.h"
+#include "paddle/fluid/framework/scope.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+class IPUPassBase : public Pass {
+ public:
+  void Init(const std::string& repr, Graph* graph) const;
+  virtual ~IPUPassBase() {}
+
+ protected:
+  mutable Graph* graph_;
+  mutable std::string repr_;
+};
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
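For context (not part of the diff): a minimal sketch of how the inference_process_pass entry point added above might be driven from C++ on a Paddle build with IPU support. The helper name RunIpuInference and the concrete attribute values are illustrative only, and the attribute types (int/bool) mirror how the pass reads them in this patch rather than a documented contract.

// Sketch only: applies inference_process_pass to a graph built from a
// ProgramDesc, assuming the graph attributes carry the types the pass expects.
#include "paddle/fluid/framework/ir/fuse_pass_base.h"  // kParamScopeAttr
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"

USE_PASS(inference_process_pass);  // pull in the registered pass at link time

namespace fw = paddle::framework;

void RunIpuInference(const fw::ProgramDesc& program, fw::Scope* scope) {
  fw::ir::Graph graph(program);

  // The pass reads the parameter scope from kParamScopeAttr.
  graph.SetNotOwned(fw::ir::kParamScopeAttr, scope);

  // IPU options consumed by InferenceProcessPass::ApplyImpl (types assumed).
  graph.Set("num_ipus", new int(2));
  graph.Set("batches_per_step", new int(2));
  graph.Set("enable_pipelining", new bool(false));
  graph.Set("batch_size", new int(1));
  graph.Set("need_avg_shard", new bool(true));

  auto pass = fw::ir::PassRegistry::Instance().Get("inference_process_pass");
  pass->Apply(&graph);  // runs the graph passes and compile passes listed above
}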