未验证 提交 8b30c1ec 编写于 作者: J jianghaicheng 提交者: GitHub

add popart_canonicalization p2 (#37965)

上级 8adaa0f0
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/ipu_runtime_replacer_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void IpuRuntimeReplacerPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "enter IpuRuntimeReplacerPass::ApplyImpl";
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);
std::vector<std::string> feed_list;
feed_list = Get<std::vector<std::string>>("feed_list");
std::vector<std::string> fetch_list;
fetch_list = Get<std::vector<std::string>>("fetch_list");
framework::OpDesc ipu_rt_op_desc;
ipu_rt_op_desc.SetType("ipu_runtime");
ipu_rt_op_desc.SetInput("FeedList", feed_list);
ipu_rt_op_desc.SetOutput("FetchList", fetch_list);
ipu_rt_op_desc.Flush();
// Create a new node for the ipu_runtime_op.
auto* ipu_rt_node = graph->CreateOpNode(&ipu_rt_op_desc);
for (auto* node : graph->Nodes()) {
if (node->IsVar()) {
for (auto feed : feed_list) {
if (node->Name() == feed) {
IR_NODE_LINK_TO(node, ipu_rt_node);
}
}
for (auto fetch : fetch_list) {
if (node->Name() == fetch) {
IR_NODE_LINK_TO(ipu_rt_node, node);
}
}
}
}
// set ipu_runtime_op dtype attr
if (fetch_list.size() == 1) {
for (auto* node : graph->Nodes()) {
if (node->IsVar()) {
for (auto fetch : fetch_list) {
if (node->Name() == fetch) {
ipu_rt_node->Op()->SetAttr("dtype", node->Var()->GetDataType());
}
}
}
}
}
// Remove unneeded nodes.
std::unordered_set<const Node*> marked_nodes;
for (auto* node : graph->Nodes()) {
if (node->IsOp()) {
auto* op_desc = node->Op();
if (op_desc->Type() != "ipu_runtime") {
marked_nodes.insert(node);
}
}
}
GraphSafeRemoveNodes(graph, marked_nodes);
VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);
VLOG(10) << "leave IpuRuntimeReplacerPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(ipu_runtime_replacer_pass,
paddle::framework::ir::IpuRuntimeReplacerPass)
.RequirePassAttr("feed_list")
.RequirePassAttr("fetch_list");
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class IpuRuntimeReplacerPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/optimizer_extract_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace paddle {
namespace framework {
namespace ir {
void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "enter IpuOptimizerExtractPass::ApplyImpl";
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);
auto ipu_backend = paddle::platform::ipu::IpuBackend::GetInstance();
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()) {
int op_role = BOOST_GET_CONST(
int, node->Op()->GetAttr(
framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
// graph usually have multiple optimizer node for different parameter,
// and these node have the same type and attr value usually
if ((op_role == static_cast<int>(framework::OpRole::kOptimize))) {
ipu_backend->GetExecutor().SetOptimizerType(node->Op()->Type());
VLOG(10) << "found optimizer type: " << node->Op()->Type();
for (const std::string& attr_name : node->Op()->AttrNames()) {
auto attr_type = node->Op()->GetAttrType(attr_name);
// with adam, attr are float
if (attr_type == proto::AttrType::FLOAT) {
auto attr_value =
BOOST_GET_CONST(float, node->Op()->GetAttr(attr_name));
ipu_backend->GetExecutor().SetOptimizerAttr(attr_name, attr_value);
} else {
VLOG(10) << "Skip " << attr_type;
}
}
auto lr_var_name = node->Op()->Input("LearningRate");
PADDLE_ENFORCE_EQ(lr_var_name.size(), 1u,
platform::errors::InvalidArgument(
"In op(%s), find input(LearningRate) failed.",
node->Op()->Type()));
ipu_backend->GetExecutor().SetLRVarName(lr_var_name[0]);
}
if ((op_role == static_cast<int>(framework::OpRole::kLoss))) {
VLOG(10) << "found loss op type: " << node->Op()->Type();
auto outputs = node->Op()->Outputs();
PADDLE_ENFORCE_EQ(
outputs.size(), 1,
platform::errors::InvalidArgument("Can only support one loss key"));
auto losses_name = outputs.begin()->second;
PADDLE_ENFORCE_EQ(losses_name.size(), 1,
platform::errors::InvalidArgument(
"Can only support one loss name"));
ipu_backend->GetExecutor().SetLoss(losses_name[0]);
}
}
}
VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);
VLOG(10) << "leave IpuOptimizerExtractPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(optimizer_extract_pass,
paddle::framework::ir::IpuOptimizerExtractPass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class IpuOptimizerExtractPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/optimizer_state_align_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/common.h"
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
namespace paddle {
namespace framework {
namespace ir {
using paddle::platform::ipu::IpuBackend;
using framework::ir::Graph;
using framework::ir::Node;
void IpuOptimizerStateAlignPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "enter IpuOptimizerStateAlignPass::ApplyImpl";
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);
auto ipu_backend = IpuBackend::GetInstance();
const auto* scope_ = ipu_backend->GetScope();
for (auto* node : graph->Nodes()) {
if (node->IsOp() && node->Op()) {
int op_role = BOOST_GET_CONST(
int, node->Op()->GetAttr(
framework::OpProtoAndCheckerMaker::OpRoleAttrName()));
if ((op_role == static_cast<int>(framework::OpRole::kOptimize))) {
auto inputs = node->Op()->Inputs();
if (inputs.count(platform::ipu::sBeta1Pow)) {
auto var = scope_->GetVar(inputs.at(platform::ipu::sBeta1Pow)[0]);
auto data = var->GetMutable<framework::LoDTensor>()->data<float>();
auto beta = BOOST_GET_CONST(
float, node->Op()->GetAttr(platform::ipu::sBeta1));
// ensure current save with beta1pow, rather than step.
// beta1pow = beta1 ^ (step + 1). Just set beta1pow because popart
// support single Step__
bool save_with_beta1pow = (data[0] < 1.0f) && (data[0] > 0.0f);
float step = 0;
float beta_acc = beta;
while (beta_acc > data[0] && save_with_beta1pow) {
beta_acc *= beta;
step += 1;
}
if (save_with_beta1pow) {
data[0] = step;
}
}
}
}
}
VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);
VLOG(10) << "leave IpuOptimizerStateAlignPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(optimizer_state_align_pass,
paddle::framework::ir::IpuOptimizerStateAlignPass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
/*
* This pass should only affect optimizer that need bias correction,
* include Adam/Lamb.
*/
class IpuOptimizerStateAlignPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/post_canonicalization.h"
namespace paddle {
namespace framework {
namespace ir {
using framework::ir::Graph;
using framework::ir::Node;
using platform::ipu::SymbolHandler;
void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
VLOG(10) << "enter PopartCanonicalizationPass::ApplyImpl";
VLOG(10) << "Raw Graph: ";
VLOG(10) << DebugString(graph);
auto nodes = graph->Nodes();
for (auto* node : nodes) {
if (!node->IsOp()) {
continue;
}
auto* op = node->Op();
auto op_type = op->Type();
ir::Node* new_node = nullptr;
SymbolHandler handler = platform::ipu::GetHandler(op_type);
if (handler) {
VLOG(11) << "Raw Paddle Node:";
VLOG(11) << node->Op()->Proto()->DebugString();
new_node = handler(graph, node);
VLOG(11) << "Post Popart Node:";
VLOG(11) << new_node->Op()->Proto()->DebugString();
platform::ipu::ClearNode(node);
graph->RemoveNode(node);
} else {
LOG(ERROR) << "Can not find OpHandler for op_type: " << op_type;
}
}
VLOG(10) << "Post Graph: ";
VLOG(10) << DebugString(graph);
VLOG(10) << "leave PopartCanonicalizationPass::ApplyImpl";
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(popart_canonicalization_pass,
paddle::framework::ir::PopartCanonicalizationPass);
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/ipu/ipu_pass_base.h"
namespace paddle {
namespace framework {
namespace ir {
class PopartCanonicalizationPass : public IPUPassBase {
protected:
void ApplyImpl(ir::Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
# IPU
IF(WITH_IPU)
FILE(GLOB POPART_CANONICALIZATION_SRC ${PADDLE_SOURCE_DIR}/paddle/fluid/platform/device/ipu/popart_canonicalization/*.cc)
cc_library(ipu_device SRCS device.cc DEPS enforce popart)
cc_library(ipu_utils SRCS ipu_utils.cc DEPS memory framework_proto popart)
cc_library(ipu_strategy SRCS ipu_strategy.cc DEPS popart graph framework_proto enforce)
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
namespace {
Node *activation_op_handler(Graph *graph, Node *node, const std::string &type) {
auto new_node = CreateBaseOp(graph, node, type, {GetInputVarNode("X", node)},
node->outputs);
return new_node;
}
Node *relu_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_relu");
}
Node *tanh_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_tanh");
}
Node *log_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_log");
}
Node *sigmoid_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_sigmoid");
}
Node *sqrt_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_sqrt");
}
Node *gelu_handler(Graph *graph, Node *node) {
return activation_op_handler(graph, node, "popart_gelu_v2");
}
Node *log_softmax_handler(Graph *graph, Node *node) {
auto axis = BOOST_GET_CONST(int, node->Op()->GetAttr("axis"));
auto new_softmax = CreateSoftmaxOpset11(graph, node, node->inputs, {}, axis);
return CreateBaseOp(graph, node, "popart_log", new_softmax->outputs,
node->outputs);
}
REGISTER_HANDLER(relu, relu_handler);
REGISTER_HANDLER(tanh, tanh_handler);
REGISTER_HANDLER(log, log_handler);
REGISTER_HANDLER(sigmoid, sigmoid_handler);
REGISTER_HANDLER(sqrt, sqrt_handler);
REGISTER_HANDLER(gelu, gelu_handler);
REGISTER_HANDLER(log_softmax, log_softmax_handler);
} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
namespace paddle {
namespace platform {
namespace ipu {
// This avoids the static initialisation order fiasco,
std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() {
static std::unordered_map<std::string, SymbolHandler> symbol_handlers;
return symbol_handlers;
}
bool RegisterHandler(const std::string &symbol, const SymbolHandler &handler) {
if (SymbolHandlers().count(symbol) != 0) {
LOG(WARNING) << "Trying to register popart handler twice for operator: "
<< symbol;
return false;
}
bool new_handler = SymbolHandlers().emplace(symbol, handler).second;
return new_handler;
}
// Return a pointer to a handler if one is registered for this kind of node or
// an empty std::function otherwise.
SymbolHandler GetHandler(const std::string &kind) {
auto it = SymbolHandlers().find(kind);
if (it != SymbolHandlers().end()) {
return it->second;
}
return {};
}
void ConnectNodes(Node *first_node, Node *next_node) {
first_node->outputs.push_back(next_node);
next_node->inputs.push_back(first_node);
}
void DisConnectNodes(Node *first_node, Node *next_node) {
auto rm_by_value = [&](std::vector<Node *> &vec, Node *n) {
vec.erase(std::remove(vec.begin(), vec.end(), n), vec.end());
};
rm_by_value(first_node->outputs, next_node);
rm_by_value(next_node->inputs, first_node);
rm_by_value(first_node->inputs, next_node);
rm_by_value(next_node->outputs, first_node);
}
void ClearNode(Node *node) {
auto rm_by_value = [&](std::vector<Node *> &vec, Node *n) {
vec.erase(std::remove(vec.begin(), vec.end(), n), vec.end());
};
for (auto *node_in : node->inputs) {
rm_by_value(node_in->outputs, node);
}
for (auto *node_out : node->outputs) {
rm_by_value(node_out->inputs, node);
}
}
void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
bool override) {
if (new_op->HasAttr(attr_name) && !override) {
return;
}
if (op->HasAttr(attr_name)) {
VLOG(10) << "Copying attr: " << attr_name << " from " << op->Type()
<< " to " << new_op->Type();
new_op->SetAttr(attr_name, op->GetAttr(attr_name));
new_op->Flush();
}
}
const int VarType2OnnxDtype(const int type) {
auto dtype = static_cast<framework::proto::VarType::Type>(type);
switch (dtype) {
case framework::proto::VarType::BOOL:
return static_cast<int>(ONNXDataType::BOOL);
case framework::proto::VarType::INT16:
return static_cast<int>(ONNXDataType::INT16);
case framework::proto::VarType::INT32:
return static_cast<int>(ONNXDataType::INT32);
case framework::proto::VarType::INT64:
return static_cast<int>(ONNXDataType::INT64);
case framework::proto::VarType::FP16:
return static_cast<int>(ONNXDataType::FLOAT16);
case framework::proto::VarType::FP32:
return static_cast<int>(ONNXDataType::FLOAT);
case framework::proto::VarType::FP64:
return static_cast<int>(ONNXDataType::DOUBLE);
case framework::proto::VarType::UINT8:
return static_cast<int>(ONNXDataType::UINT8);
case framework::proto::VarType::INT8:
return static_cast<int>(ONNXDataType::INT8);
case framework::proto::VarType::BF16:
return static_cast<int>(ONNXDataType::BFLOAT16);
case framework::proto::VarType::COMPLEX64:
return static_cast<int>(ONNXDataType::COMPLEX64);
case framework::proto::VarType::COMPLEX128:
return static_cast<int>(ONNXDataType::COMPLEX128);
default:
PADDLE_THROW(
platform::errors::Unimplemented("Unsupported data type: %d.", dtype));
}
}
const std::string VarType2PopStr(const int type) {
auto dtype = static_cast<framework::proto::VarType::Type>(type);
switch (dtype) {
case framework::proto::VarType::UINT8:
return "UINT8";
case framework::proto::VarType::INT8:
return "INT8";
case framework::proto::VarType::INT16:
return "INT16";
case framework::proto::VarType::INT32:
return "INT32";
case framework::proto::VarType::INT64:
return "INT64";
case framework::proto::VarType::BOOL:
return "BOOL";
case framework::proto::VarType::FP64:
return "DOUBLE";
case framework::proto::VarType::FP32:
return "FLOAT";
case framework::proto::VarType::FP16:
return "FLOAT16";
default:
PADDLE_THROW(
paddle::platform::errors::Unavailable("Unsupported data type."));
}
}
Node *GetInputVarNode(const std::string &input_name, const Node *op_node,
const int id) {
auto var_name = op_node->Op()->Input(input_name).at(id);
return GetInputVarNodeByVarName(var_name, op_node);
}
Node *GetOutputVarNode(const std::string &output_name, const Node *op_node,
const int id) {
auto var_name = op_node->Op()->Output(output_name).at(id);
return GetOutputVarNodeByVarName(var_name, op_node);
}
Node *GetInputVarNodeByVarName(const std::string &var_name,
const Node *op_node) {
for (auto *var : op_node->inputs) {
if (var->Name() == var_name) {
return var;
}
}
return nullptr;
}
Node *GetOutputVarNodeByVarName(const std::string &var_name,
const Node *op_node) {
for (auto *var : op_node->outputs) {
if (var->Name() == var_name) {
return var;
}
}
return nullptr;
}
const bool is_float_equal(float a, float b, float eps) {
return std::fabs(a - b) <= eps;
}
} // namespace ipu
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/platform/device/ipu/ipu_utils.h"
namespace paddle {
namespace platform {
namespace ipu {
using framework::ir::Graph;
using framework::ir::Node;
using framework::OpDesc;
#define REGISTER_HANDLER(name, func) \
static bool __UNUSED_##name = \
paddle::platform::ipu::RegisterHandler(#name, func)
using SymbolHandler = std::function<Node *(Graph *, Node *)>;
std::unordered_map<std::string, SymbolHandler> &SymbolHandlers();
bool RegisterHandler(const std::string &, const SymbolHandler &);
SymbolHandler GetHandler(const std::string &);
void ConnectNodes(Node *first_node, Node *next_node);
void DisConnectNodes(Node *first_node, Node *next_node);
void ClearNode(Node *node);
void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op,
bool override = false);
const int VarType2OnnxDtype(const int type);
const std::string VarType2PopStr(const int type);
Node *GetInputVarNode(const std::string &input_name, const Node *op_node,
const int id = 0);
Node *GetOutputVarNode(const std::string &output_name, const Node *op_node,
const int id = 0);
Node *GetInputVarNodeByVarName(const std::string &var_name,
const Node *op_node);
Node *GetOutputVarNodeByVarName(const std::string &var_name,
const Node *op_node);
const bool is_float_equal(float a, float b, float eps = 1e-8);
} // namespace ipu
} // namespace platform
} // namespace paddle
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
namespace {
Node *elementwise_op_handler(Graph *graph, Node *node,
const std::string &type) {
auto *op = node->Op();
auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
int64_t x_rank = x_shape.size();
auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape();
int64_t y_rank = y_shape.size();
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
if (axis == -1 || axis == x_rank - 1 || x_rank == y_rank) {
auto new_node =
CreateBaseOp(graph, node, type,
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
node->outputs);
return new_node;
} else {
auto y_new_shape = std::vector<int64_t>(x_rank, 1);
for (int i = axis; i < axis + y_rank; ++i) {
y_new_shape[i] = y_shape[i - axis];
}
auto attrs = AttributeMap{
{"value", y_new_shape},
{"dims", std::vector<int64_t>{x_rank}},
{"dtype", ONNXDataType::INT64},
};
// constant
auto new_node_const = CreateConst(graph, node, {}, {}, attrs);
// reshape
auto new_node_reshape = CreateBaseOp(
graph, node, "popart_reshape",
{GetInputVarNode("Y", node), new_node_const->outputs[0]}, {});
// elementwise_op
auto new_node =
CreateBaseOp(graph, node, type,
{GetInputVarNode("X", node), new_node_reshape->outputs[0]},
node->outputs);
return new_node;
}
}
Node *elementwise_add_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_add");
}
Node *elementwise_sub_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_sub");
}
Node *elementwise_div_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_div");
}
Node *elementwise_mul_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_mul");
}
Node *elementwise_min_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_min");
}
Node *elementwise_max_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_max");
}
Node *elementwise_pow_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_pow");
}
Node *elementwise_mod_handler(Graph *graph, Node *node) {
return elementwise_op_handler(graph, node, "popart_mod");
}
REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
REGISTER_HANDLER(elementwise_sub, elementwise_sub_handler);
REGISTER_HANDLER(elementwise_div, elementwise_div_handler);
REGISTER_HANDLER(elementwise_mul, elementwise_mul_handler);
REGISTER_HANDLER(elementwise_min, elementwise_min_handler);
REGISTER_HANDLER(elementwise_max, elementwise_max_handler);
REGISTER_HANDLER(elementwise_pow, elementwise_pow_handler);
REGISTER_HANDLER(elementwise_mod, elementwise_mod_handler);
} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册