diff --git a/paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc b/paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..a3e020714e1db90154b90b3785ea99e9eabd3256 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.cc @@ -0,0 +1,97 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h" + +#include "paddle/fluid/framework/ir/graph_pattern_detector.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" + +namespace paddle { +namespace framework { +namespace ir { + +void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const { + VLOG(10) << "enter IpuRuntimeReplacerPass::ApplyImpl"; + VLOG(10) << "Raw Graph: "; + VLOG(10) << DebugString(graph); + + std::vector feed_list; + feed_list = Get>("feed_list"); + + std::vector fetch_list; + fetch_list = Get>("fetch_list"); + + framework::OpDesc ipu_rt_op_desc; + ipu_rt_op_desc.SetType("ipu_runtime"); + ipu_rt_op_desc.SetInput("FeedList", feed_list); + ipu_rt_op_desc.SetOutput("FetchList", fetch_list); + ipu_rt_op_desc.Flush(); + + // Create a new node for the ipu_runtime_op. + auto* ipu_rt_node = graph->CreateOpNode(&ipu_rt_op_desc); + + for (auto* node : graph->Nodes()) { + if (node->IsVar()) { + for (auto feed : feed_list) { + if (node->Name() == feed) { + IR_NODE_LINK_TO(node, ipu_rt_node); + } + } + for (auto fetch : fetch_list) { + if (node->Name() == fetch) { + IR_NODE_LINK_TO(ipu_rt_node, node); + } + } + } + } + + // set ipu_runtime_op dtype attr + if (fetch_list.size() == 1) { + for (auto* node : graph->Nodes()) { + if (node->IsVar()) { + for (auto fetch : fetch_list) { + if (node->Name() == fetch) { + ipu_rt_node->Op()->SetAttr("dtype", node->Var()->GetDataType()); + } + } + } + } + } + + // Remove unneeded nodes. + std::unordered_set marked_nodes; + for (auto* node : graph->Nodes()) { + if (node->IsOp()) { + auto* op_desc = node->Op(); + if (op_desc->Type() != "ipu_runtime") { + marked_nodes.insert(node); + } + } + } + + GraphSafeRemoveNodes(graph, marked_nodes); + + VLOG(10) << "Post Graph: "; + VLOG(10) << DebugString(graph); + VLOG(10) << "leave IpuRuntimeReplacerPass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(ipu_runtime_replacer_pass, + paddle::framework::ir::IpuRuntimeReplacerPass) + .RequirePassAttr("feed_list") + .RequirePassAttr("fetch_list"); diff --git a/paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h b/paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..ba2cc8702fa4731a388986bee8436e6bff31c586 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h @@ -0,0 +1,31 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class IpuRuntimeReplacerPass : public IPUPassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..c6be2c775bd2116e11e24d17da225d2c76679399 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc @@ -0,0 +1,91 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h" + +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/device/ipu/ipu_backend.h" + +namespace paddle { +namespace framework { +namespace ir { + +void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { + VLOG(10) << "enter IpuOptimizerExtractPass::ApplyImpl"; + VLOG(10) << "Raw Graph: "; + VLOG(10) << DebugString(graph); + + auto ipu_backend = paddle::platform::ipu::IpuBackend::GetInstance(); + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()) { + int op_role = BOOST_GET_CONST( + int, node->Op()->GetAttr( + framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + + // graph usually have multiple optimizer node for different parameter, + // and these node have the same type and attr value usually + if ((op_role == static_cast(framework::OpRole::kOptimize))) { + ipu_backend->GetExecutor().SetOptimizerType(node->Op()->Type()); + VLOG(10) << "found optimizer type: " << node->Op()->Type(); + + for (const std::string& attr_name : node->Op()->AttrNames()) { + auto attr_type = node->Op()->GetAttrType(attr_name); + // with adam, attr are float + if (attr_type == proto::AttrType::FLOAT) { + auto attr_value = + BOOST_GET_CONST(float, node->Op()->GetAttr(attr_name)); + ipu_backend->GetExecutor().SetOptimizerAttr(attr_name, attr_value); + } else { + VLOG(10) << "Skip " << attr_type; + } + } + + auto lr_var_name = node->Op()->Input("LearningRate"); + PADDLE_ENFORCE_EQ(lr_var_name.size(), 1u, + platform::errors::InvalidArgument( + "In op(%s), find input(LearningRate) failed.", + node->Op()->Type())); + + ipu_backend->GetExecutor().SetLRVarName(lr_var_name[0]); + } + + if ((op_role == static_cast(framework::OpRole::kLoss))) { + VLOG(10) << "found loss op type: " << node->Op()->Type(); + auto outputs = node->Op()->Outputs(); + PADDLE_ENFORCE_EQ( + outputs.size(), 1, + platform::errors::InvalidArgument("Can only support one loss key")); + + auto losses_name = outputs.begin()->second; + PADDLE_ENFORCE_EQ(losses_name.size(), 1, + platform::errors::InvalidArgument( + "Can only support one loss name")); + + ipu_backend->GetExecutor().SetLoss(losses_name[0]); + } + } + } + + VLOG(10) << "Post Graph: "; + VLOG(10) << DebugString(graph); + VLOG(10) << "leave IpuOptimizerExtractPass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(optimizer_extract_pass, + paddle::framework::ir::IpuOptimizerExtractPass); diff --git a/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..fd274ded8f5bd1641bd1eb4e6999a6fa38dca090 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h @@ -0,0 +1,31 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class IpuOptimizerExtractPass : public IPUPassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc b/paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..c23bfdcb154f16a933c389d9d9032364995bd58c --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.cc @@ -0,0 +1,79 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h" +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/device/ipu/common.h" +#include "paddle/fluid/platform/device/ipu/ipu_backend.h" + +namespace paddle { +namespace framework { +namespace ir { + +using paddle::platform::ipu::IpuBackend; +using framework::ir::Graph; +using framework::ir::Node; + +void IpuOptimizerStateAlignPass::ApplyImpl(ir::Graph* graph) const { + VLOG(10) << "enter IpuOptimizerStateAlignPass::ApplyImpl"; + VLOG(10) << "Raw Graph: "; + VLOG(10) << DebugString(graph); + + auto ipu_backend = IpuBackend::GetInstance(); + const auto* scope_ = ipu_backend->GetScope(); + + for (auto* node : graph->Nodes()) { + if (node->IsOp() && node->Op()) { + int op_role = BOOST_GET_CONST( + int, node->Op()->GetAttr( + framework::OpProtoAndCheckerMaker::OpRoleAttrName())); + + if ((op_role == static_cast(framework::OpRole::kOptimize))) { + auto inputs = node->Op()->Inputs(); + if (inputs.count(platform::ipu::sBeta1Pow)) { + auto var = scope_->GetVar(inputs.at(platform::ipu::sBeta1Pow)[0]); + auto data = var->GetMutable()->data(); + auto beta = BOOST_GET_CONST( + float, node->Op()->GetAttr(platform::ipu::sBeta1)); + + // ensure current save with beta1pow, rather than step. + // beta1pow = beta1 ^ (step + 1). Just set beta1pow because popart + // support single Step__ + bool save_with_beta1pow = (data[0] < 1.0f) && (data[0] > 0.0f); + float step = 0; + float beta_acc = beta; + while (beta_acc > data[0] && save_with_beta1pow) { + beta_acc *= beta; + step += 1; + } + + if (save_with_beta1pow) { + data[0] = step; + } + } + } + } + } + + VLOG(10) << "Post Graph: "; + VLOG(10) << DebugString(graph); + VLOG(10) << "leave IpuOptimizerStateAlignPass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(optimizer_state_align_pass, + paddle::framework::ir::IpuOptimizerStateAlignPass); diff --git a/paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h b/paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..21a1017d88452aa950949c44a862b78c11bc5793 --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h @@ -0,0 +1,36 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +/* + * This pass should only affect optimizer that need bias correction, + * include Adam/Lamb. + */ + +class IpuOptimizerStateAlignPass : public IPUPassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc new file mode 100644 index 0000000000000000000000000000000000000000..c97b7fd5bcb0cb7b6a4a49725d15e23832a9308f --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc @@ -0,0 +1,68 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h" + +#include "paddle/fluid/framework/ir/pass_tester_helper.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/post_canonicalization.h" + +namespace paddle { +namespace framework { +namespace ir { + +using framework::ir::Graph; +using framework::ir::Node; +using platform::ipu::SymbolHandler; + +void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const { + VLOG(10) << "enter PopartCanonicalizationPass::ApplyImpl"; + VLOG(10) << "Raw Graph: "; + VLOG(10) << DebugString(graph); + + auto nodes = graph->Nodes(); + for (auto* node : nodes) { + if (!node->IsOp()) { + continue; + } + auto* op = node->Op(); + auto op_type = op->Type(); + + ir::Node* new_node = nullptr; + SymbolHandler handler = platform::ipu::GetHandler(op_type); + if (handler) { + VLOG(11) << "Raw Paddle Node:"; + VLOG(11) << node->Op()->Proto()->DebugString(); + new_node = handler(graph, node); + VLOG(11) << "Post Popart Node:"; + VLOG(11) << new_node->Op()->Proto()->DebugString(); + + platform::ipu::ClearNode(node); + graph->RemoveNode(node); + } else { + LOG(ERROR) << "Can not find OpHandler for op_type: " << op_type; + } + } + + VLOG(10) << "Post Graph: "; + VLOG(10) << DebugString(graph); + VLOG(10) << "leave PopartCanonicalizationPass::ApplyImpl"; +} + +} // namespace ir +} // namespace framework +} // namespace paddle + +REGISTER_PASS(popart_canonicalization_pass, + paddle::framework::ir::PopartCanonicalizationPass); diff --git a/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h new file mode 100644 index 0000000000000000000000000000000000000000..6690873f2a9ac2532f02b59b241a2858dc978beb --- /dev/null +++ b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h @@ -0,0 +1,30 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h" + +namespace paddle { +namespace framework { +namespace ir { + +class PopartCanonicalizationPass : public IPUPassBase { + protected: + void ApplyImpl(ir::Graph* graph) const override; +}; + +} // namespace ir +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/CMakeLists.txt b/paddle/fluid/platform/device/ipu/CMakeLists.txt index 25629ba74d91523000b673a60e17cc7644ed752d..9be12cbf6d43769d19cf7f0b77d14d5ad02139ee 100644 --- a/paddle/fluid/platform/device/ipu/CMakeLists.txt +++ b/paddle/fluid/platform/device/ipu/CMakeLists.txt @@ -1,5 +1,5 @@ -# IPU IF(WITH_IPU) + FILE(GLOB POPART_CANONICALIZATION_SRC ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/device/ipu/popart_canonicalization/*.cc) cc_library(ipu_device SRCS device.cc DEPS enforce popart) cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart) cc_library(ipu_strategy SRCS ipu_strategy.cc DEPS popart graph framework_proto enforce) diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..5793c4c0e3ca69ec6eb9b7161dd62f95d0ba314a --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc @@ -0,0 +1,72 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *activation_op_handler(Graph *graph, Node *node, const std::string &type) { + auto new_node = CreateBaseOp(graph, node, type, {GetInputVarNode("X", node)}, + node->outputs); + return new_node; +} + +Node *relu_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_relu"); +} + +Node *tanh_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_tanh"); +} + +Node *log_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_log"); +} + +Node *sigmoid_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_sigmoid"); +} + +Node *sqrt_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_sqrt"); +} + +Node *gelu_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_gelu_v2"); +} + +Node *log_softmax_handler(Graph *graph, Node *node) { + auto axis = BOOST_GET_CONST(int, node->Op()->GetAttr("axis")); + auto new_softmax = CreateSoftmaxOpset11(graph, node, node->inputs, {}, axis); + return CreateBaseOp(graph, node, "popart_log", new_softmax->outputs, + node->outputs); +} + +REGISTER_HANDLER(relu, relu_handler); +REGISTER_HANDLER(tanh, tanh_handler); +REGISTER_HANDLER(log, log_handler); +REGISTER_HANDLER(sigmoid, sigmoid_handler); +REGISTER_HANDLER(sqrt, sqrt_handler); +REGISTER_HANDLER(gelu, gelu_handler); +REGISTER_HANDLER(log_softmax, log_softmax_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc new file mode 100644 index 0000000000000000000000000000000000000000..d46fc55ec6ce0de12e2a610be1937f9e3a948c02 --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc @@ -0,0 +1,185 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" + +namespace paddle { +namespace platform { +namespace ipu { + +// This avoids the static initialisation order fiasco, +std::unordered_map &SymbolHandlers() { + static std::unordered_map symbol_handlers; + return symbol_handlers; +} + +bool RegisterHandler(const std::string &symbol, const SymbolHandler &handler) { + if (SymbolHandlers().count(symbol) != 0) { + LOG(WARNING) << "Trying to register popart handler twice for operator: " + << symbol; + return false; + } + bool new_handler = SymbolHandlers().emplace(symbol, handler).second; + return new_handler; +} + +// Return a pointer to a handler if one is registered for this kind of node or +// an empty std::function otherwise. +SymbolHandler GetHandler(const std::string &kind) { + auto it = SymbolHandlers().find(kind); + if (it != SymbolHandlers().end()) { + return it->second; + } + return {}; +} + +void ConnectNodes(Node *first_node, Node *next_node) { + first_node->outputs.push_back(next_node); + next_node->inputs.push_back(first_node); +} + +void DisConnectNodes(Node *first_node, Node *next_node) { + auto rm_by_value = [&](std::vector &vec, Node *n) { + vec.erase(std::remove(vec.begin(), vec.end(), n), vec.end()); + }; + rm_by_value(first_node->outputs, next_node); + rm_by_value(next_node->inputs, first_node); + rm_by_value(first_node->inputs, next_node); + rm_by_value(next_node->outputs, first_node); +} + +void ClearNode(Node *node) { + auto rm_by_value = [&](std::vector &vec, Node *n) { + vec.erase(std::remove(vec.begin(), vec.end(), n), vec.end()); + }; + for (auto *node_in : node->inputs) { + rm_by_value(node_in->outputs, node); + } + for (auto *node_out : node->outputs) { + rm_by_value(node_out->inputs, node); + } +} + +void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op, + bool override) { + if (new_op->HasAttr(attr_name) && !override) { + return; + } + if (op->HasAttr(attr_name)) { + VLOG(10) << "Copying attr: " << attr_name << " from " << op->Type() + << " to " << new_op->Type(); + new_op->SetAttr(attr_name, op->GetAttr(attr_name)); + new_op->Flush(); + } +} + +const int VarType2OnnxDtype(const int type) { + auto dtype = static_cast(type); + switch (dtype) { + case framework::proto::VarType::BOOL: + return static_cast(ONNXDataType::BOOL); + case framework::proto::VarType::INT16: + return static_cast(ONNXDataType::INT16); + case framework::proto::VarType::INT32: + return static_cast(ONNXDataType::INT32); + case framework::proto::VarType::INT64: + return static_cast(ONNXDataType::INT64); + case framework::proto::VarType::FP16: + return static_cast(ONNXDataType::FLOAT16); + case framework::proto::VarType::FP32: + return static_cast(ONNXDataType::FLOAT); + case framework::proto::VarType::FP64: + return static_cast(ONNXDataType::DOUBLE); + case framework::proto::VarType::UINT8: + return static_cast(ONNXDataType::UINT8); + case framework::proto::VarType::INT8: + return static_cast(ONNXDataType::INT8); + case framework::proto::VarType::BF16: + return static_cast(ONNXDataType::BFLOAT16); + case framework::proto::VarType::COMPLEX64: + return static_cast(ONNXDataType::COMPLEX64); + case framework::proto::VarType::COMPLEX128: + return static_cast(ONNXDataType::COMPLEX128); + default: + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported data type: %d.", dtype)); + } +} + +const std::string VarType2PopStr(const int type) { + auto dtype = static_cast(type); + switch (dtype) { + case framework::proto::VarType::UINT8: + return "UINT8"; + case framework::proto::VarType::INT8: + return "INT8"; + case framework::proto::VarType::INT16: + return "INT16"; + case framework::proto::VarType::INT32: + return "INT32"; + case framework::proto::VarType::INT64: + return "INT64"; + case framework::proto::VarType::BOOL: + return "BOOL"; + case framework::proto::VarType::FP64: + return "DOUBLE"; + case framework::proto::VarType::FP32: + return "FLOAT"; + case framework::proto::VarType::FP16: + return "FLOAT16"; + default: + PADDLE_THROW( + paddle::platform::errors::Unavailable("Unsupported data type.")); + } +} + +Node *GetInputVarNode(const std::string &input_name, const Node *op_node, + const int id) { + auto var_name = op_node->Op()->Input(input_name).at(id); + return GetInputVarNodeByVarName(var_name, op_node); +} + +Node *GetOutputVarNode(const std::string &output_name, const Node *op_node, + const int id) { + auto var_name = op_node->Op()->Output(output_name).at(id); + return GetOutputVarNodeByVarName(var_name, op_node); +} + +Node *GetInputVarNodeByVarName(const std::string &var_name, + const Node *op_node) { + for (auto *var : op_node->inputs) { + if (var->Name() == var_name) { + return var; + } + } + return nullptr; +} + +Node *GetOutputVarNodeByVarName(const std::string &var_name, + const Node *op_node) { + for (auto *var : op_node->outputs) { + if (var->Name() == var_name) { + return var; + } + } + return nullptr; +} + +const bool is_float_equal(float a, float b, float eps) { + return std::fabs(a - b) <= eps; +} + +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h new file mode 100644 index 0000000000000000000000000000000000000000..c1b2bd0c8b5fd454642be8eef733234c89c4d32a --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h @@ -0,0 +1,64 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/platform/device/ipu/ipu_utils.h" + +namespace paddle { +namespace platform { +namespace ipu { + +using framework::ir::Graph; +using framework::ir::Node; +using framework::OpDesc; + +#define REGISTER_HANDLER(name, func) \ + static bool __UNUSED_##name = \ + paddle::platform::ipu::RegisterHandler(#name, func) + +using SymbolHandler = std::function; + +std::unordered_map &SymbolHandlers(); + +bool RegisterHandler(const std::string &, const SymbolHandler &); + +SymbolHandler GetHandler(const std::string &); + +void ConnectNodes(Node *first_node, Node *next_node); +void DisConnectNodes(Node *first_node, Node *next_node); +void ClearNode(Node *node); +void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op, + bool override = false); + +const int VarType2OnnxDtype(const int type); +const std::string VarType2PopStr(const int type); + +Node *GetInputVarNode(const std::string &input_name, const Node *op_node, + const int id = 0); +Node *GetOutputVarNode(const std::string &output_name, const Node *op_node, + const int id = 0); +Node *GetInputVarNodeByVarName(const std::string &var_name, + const Node *op_node); +Node *GetOutputVarNodeByVarName(const std::string &var_name, + const Node *op_node); + +const bool is_float_equal(float a, float b, float eps = 1e-8); + +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc new file mode 100644 index 0000000000000000000000000000000000000000..f0c19cac3a6c3f5db76ab70cabb7d49449e030a1 --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/elementwise_ops.cc @@ -0,0 +1,108 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *elementwise_op_handler(Graph *graph, Node *node, + const std::string &type) { + auto *op = node->Op(); + auto x_shape = GetInputVarNode("X", node)->Var()->GetShape(); + int64_t x_rank = x_shape.size(); + auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape(); + int64_t y_rank = y_shape.size(); + + auto axis = BOOST_GET_CONST(int, op->GetAttr("axis")); + if (axis == -1 || axis == x_rank - 1 || x_rank == y_rank) { + auto new_node = + CreateBaseOp(graph, node, type, + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + node->outputs); + return new_node; + } else { + auto y_new_shape = std::vector(x_rank, 1); + for (int i = axis; i < axis + y_rank; ++i) { + y_new_shape[i] = y_shape[i - axis]; + } + auto attrs = AttributeMap{ + {"value", y_new_shape}, + {"dims", std::vector{x_rank}}, + {"dtype", ONNXDataType::INT64}, + }; + // constant + auto new_node_const = CreateConst(graph, node, {}, {}, attrs); + // reshape + auto new_node_reshape = CreateBaseOp( + graph, node, "popart_reshape", + {GetInputVarNode("Y", node), new_node_const->outputs[0]}, {}); + // elementwise_op + auto new_node = + CreateBaseOp(graph, node, type, + {GetInputVarNode("X", node), new_node_reshape->outputs[0]}, + node->outputs); + return new_node; + } +} + +Node *elementwise_add_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_add"); +} + +Node *elementwise_sub_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_sub"); +} + +Node *elementwise_div_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_div"); +} + +Node *elementwise_mul_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_mul"); +} + +Node *elementwise_min_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_min"); +} + +Node *elementwise_max_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_max"); +} + +Node *elementwise_pow_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_pow"); +} + +Node *elementwise_mod_handler(Graph *graph, Node *node) { + return elementwise_op_handler(graph, node, "popart_mod"); +} + +REGISTER_HANDLER(elementwise_add, elementwise_add_handler); +REGISTER_HANDLER(elementwise_sub, elementwise_sub_handler); +REGISTER_HANDLER(elementwise_div, elementwise_div_handler); +REGISTER_HANDLER(elementwise_mul, elementwise_mul_handler); +REGISTER_HANDLER(elementwise_min, elementwise_min_handler); +REGISTER_HANDLER(elementwise_max, elementwise_max_handler); +REGISTER_HANDLER(elementwise_pow, elementwise_pow_handler); +REGISTER_HANDLER(elementwise_mod, elementwise_mod_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle