From 0a04b8a9980e3a409642201707c0f1d95be4c5d8 Mon Sep 17 00:00:00 2001 From: Allen Guo Date: Mon, 11 Jul 2022 20:27:40 +0800 Subject: [PATCH] [IPU] support more ops 0/N (#44204) * add authors Co-authored-by: Allen Guo Co-authored-by: Zhixin Yao Co-authored-by: Zhaorui Chen * squash cpp changes 1/N * clean code Co-authored-by: Zhixin Yao Co-authored-by: Zhaorui Chen --- .../ir/ipu/optimizer_extract_pass.cc | 13 - .../ir/ipu/popart_canonicalization_pass.cc | 11 + .../fluid/platform/device/ipu/ipu_compiler.cc | 39 +- .../canonicalization_utils.cc | 53 ++ .../canonicalization_utils.h | 7 + .../ipu/popart_canonicalization/logic_ops.cc | 34 ++ .../ipu/popart_canonicalization/loss_ops.cc | 508 ++++++++++++++++++ .../ipu/popart_canonicalization/math_ops.cc | 287 +++------- .../ipu/popart_canonicalization/op_builder.cc | 34 +- .../ipu/popart_canonicalization/op_builder.h | 6 + .../ipu/popart_canonicalization/other_ops.cc | 12 - .../ipu/popart_canonicalization/reduce_ops.cc | 52 ++ .../device/ipu/supported_ops_autogen.h | 1 + 13 files changed, 795 insertions(+), 262 deletions(-) create mode 100644 paddle/fluid/platform/device/ipu/popart_canonicalization/loss_ops.cc diff --git a/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc index f28696194e5..b45a39aaa86 100644 --- a/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc +++ b/paddle/fluid/framework/ir/ipu/optimizer_extract_pass.cc @@ -287,19 +287,6 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const { } else if (op_role == OpRole::kLRSched) { // op_role == OpRole::kLRSched | OpRole::kOptimize new_op.SetAttr("with_lr_sched", true); - } else if (op_type == "identity_loss") { - auto outputs = op->Outputs(); - PADDLE_ENFORCE_EQ( - outputs.size(), - 1, - platform::errors::InvalidArgument("Can only support one loss key")); - auto losses = outputs.begin()->second; - PADDLE_ENFORCE_EQ( - losses.size(), - 1, - platform::errors::InvalidArgument("Can only support one loss name")); - auto loss_var = losses.front(); - new_op.SetAttr("loss_var", loss_var); } } diff --git a/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc index 6806e44f095..222ca619c22 100644 --- a/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc +++ b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc @@ -30,8 +30,13 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const { auto custom_ops = Get>("custom_ops"); std::vector missing_ops; auto sorted_ops = TopologySortOperations(*graph); + std::unordered_set delete_nodes; for (auto* node : sorted_ops) { auto* op = node->Op(); + if (platform::ipu::IsMarkedForDeletion(node)) { + delete_nodes.insert(node); + continue; + } auto op_type = op->Type(); ir::Node* new_node = nullptr; @@ -67,6 +72,12 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const { "Found unimplemented op_handler(s) for IPU")); } + for (auto* node : delete_nodes) { + // TODO(czr): possible remove + platform::ipu::ClearNode(node); + graph->RemoveNode(node); + } + // post popart_canonicalization VLOG(10) << "Post Graph: "; diff --git a/paddle/fluid/platform/device/ipu/ipu_compiler.cc b/paddle/fluid/platform/device/ipu/ipu_compiler.cc index 930af7e1470..09e68ab5187 100644 --- a/paddle/fluid/platform/device/ipu/ipu_compiler.cc +++ b/paddle/fluid/platform/device/ipu/ipu_compiler.cc @@ -445,6 +445,7 @@ void Compiler::LowerWeights(const Scope* scope) { for (size_t i = 0; i < tensor.dims().size(); ++i) { shape.push_back(tensor.dims().at(i)); } + popart::TensorInfo tensor_info(dtype, shape); popart::ConstVoidData const_data{tensor.data(), tensor_info}; if (!node->outputs.empty()) { @@ -530,21 +531,26 @@ void Compiler::LowerOptimizer(const Scope* scope) { auto raw_type = BOOST_GET_CONST(std::string, op_desc->GetAttr("raw_type")); resources_->optimizer_type = raw_type; - auto loss_var = - BOOST_GET_CONST(std::string, op_desc->GetAttr("loss_var")); - resources_->loss_var = resources_->tensors[loss_var]; resources_->with_lr_sched = BOOST_GET_CONST(bool, op_desc->GetAttr("with_lr_sched")); if (ipu_strategy_->is_dynamic) { + // loss_var in dy2static is set by identity_loss. And lr is + // passed by ipu_strategy. resources_->lr = ipu_strategy_->lr; - } else if (op_desc->HasAttr("lr_var")) { - auto lr_var = BOOST_GET_CONST(std::string, op_desc->GetAttr("lr_var")); - resources_->lr_var = lr_var; - resources_->lr = GetSingleVarFromScope(scope, lr_var); } else { - // adadelta has no lr - resources_->lr = 0.01f; - resources_->with_lr_sched = false; + auto loss_var = + BOOST_GET_CONST(std::string, op_desc->GetAttr("loss_var")); + resources_->loss_var = resources_->tensors[loss_var]; + if (op_desc->HasAttr("lr_var")) { + auto lr_var = + BOOST_GET_CONST(std::string, op_desc->GetAttr("lr_var")); + resources_->lr_var = lr_var; + resources_->lr = GetSingleVarFromScope(scope, lr_var); + } else { + // adadelta has no lr + resources_->lr = 0.01f; + resources_->with_lr_sched = false; + } } VLOG(10) << "Set initial lr: " << resources_->lr; @@ -766,6 +772,19 @@ void Compiler::LowerOptimizer(const Scope* scope) { PADDLE_THROW(platform::errors::Unimplemented( "optimizer %s is not implemented", type)); } + } else if (op_type == "popart_identity_loss") { + auto outputs = op_desc->Outputs(); + PADDLE_ENFORCE_EQ( + outputs.size(), + 1, + platform::errors::InvalidArgument("Can only support one loss key")); + auto losses = outputs.begin()->second; + PADDLE_ENFORCE_EQ( + losses.size(), + 1, + platform::errors::InvalidArgument("Can only support one loss name")); + auto loss_var = losses.front(); + resources_->loss_var = resources_->tensors[loss_var]; } } } diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc index 44fdf764c5b..c4960616b9d 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc @@ -138,6 +138,59 @@ const ONNXDataType GetOutputVarDType(const Node *node, return GetVarDType(out_node); } +bool IsLastVarNode(Node *node) { + return node->IsVar() && node->outputs.size() == 0; +} + +void MarkNodeForDeletion(Node *node) { node->Op()->SetAttr("delete_node", 1); } + +bool IsMarkedForDeletion(Node *node) { + return node->Op()->HasAttr("delete_node") && + BOOST_GET_CONST(int, node->Op()->GetAttr("delete_node")) > 0; +} + +int RemoveTailReduction(Graph *graph, + Node *loss_op, + const std::string &output_var_name) { + // Sum: 0. Mean: 1. None: 2 + int reduction = 2; + Node *reduction_op; + auto loss_output = GetOutputVarNode(output_var_name, loss_op); + for (auto sub_node : loss_output->outputs) { + if (!sub_node->IsOp()) continue; + if (sub_node->Op()->Type() == "reduce_sum") { + reduction = 0; + reduction_op = sub_node; + } else if (sub_node->Op()->Type() == "reduce_mean") { + reduction = 1; + reduction_op = sub_node; + } + } + if (reduction == 2) return reduction; + auto reduction_out = reduction_op->outputs[0]; + loss_op->Op()->SetOutput(output_var_name, + std::vector({reduction_out->Name()})); + MarkNodeForDeletion(reduction_op); + DisConnectNodes(loss_output, reduction_op); + DisConnectNodes(reduction_op, reduction_out); + ConnectNodes(loss_op, reduction_out); + + return reduction; +} + +int ConvertToPopartReduction(const std::string &reduction) { + // Sum: 0. Mean: 1. None: 2 + if (reduction == "sum") { + return 0; + } else if (reduction == "mean") { + return 1; + } else if (reduction == "none") { + return 2; + } + PADDLE_THROW(platform::errors::InvalidArgument( + "reduction %s is not supported on ipu.", reduction)); +} + } // namespace ipu } // namespace platform } // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h index 536b69a39b9..611d863c496 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h @@ -85,6 +85,13 @@ const bool is_float_equal(float a, float b, float eps = 1e-8); const ONNXDataType GetVarDType(const Node *node); const ONNXDataType GetOutputVarDType(const Node *node, const std::string &output_name = "Out"); +void MarkNodeForDeletion(Node *node); +bool IsMarkedForDeletion(Node *node); +bool IsLastVarNode(Node *node); +int RemoveTailReduction(Graph *graph, + Node *loss_op, + const std::string &output_var_name); +int ConvertToPopartReduction(const std::string &reduction); } // namespace ipu } // namespace platform diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc index c10a30997a4..155c11b03b8 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc @@ -91,6 +91,38 @@ Node *less_than_handler(Graph *graph, Node *node) { {}); } +Node *greater_equal_handler(Graph *graph, Node *node) { + auto less_op = + CreateBaseOp(graph, + node, + "popart_less", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {}, + {}); + return CreateBaseOp(graph, + node, + "popart_logical_not", + less_op->outputs, + {GetOutputVarNode("Out", node)}, + {}); +} + +Node *less_equal_handler(Graph *graph, Node *node) { + auto less_op = + CreateBaseOp(graph, + node, + "popart_greater", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {}, + {}); + return CreateBaseOp(graph, + node, + "popart_logical_not", + less_op->outputs, + {GetOutputVarNode("Out", node)}, + {}); +} + } // namespace } // namespace ipu } // namespace platform @@ -103,3 +135,5 @@ REGISTER_HANDLER(logical_or, logical_or_handler); REGISTER_HANDLER(logical_and, logical_and_handler); REGISTER_HANDLER(greater_than, greater_than_handler); REGISTER_HANDLER(less_than, less_than_handler); +REGISTER_HANDLER(greater_equal, greater_equal_handler); +REGISTER_HANDLER(less_equal, less_equal_handler); diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/loss_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/loss_ops.cc new file mode 100644 index 00000000000..438304fcfc7 --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/loss_ops.cc @@ -0,0 +1,508 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/ipu_backend.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +bool is_dynamic_graph() { + auto *ipu_backend = platform::ipu::IpuBackend::GetInstance(); + return ipu_backend->GetIpuStrategy()->is_dynamic; +} + +Node *identity_loss_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction")); + return CreateIdentityLossOp( + graph, node, node->inputs, node->outputs, reduction); +} + +Node *cross_entropy_general_handler(Graph *graph, + Node *node, + Node *logits, + Node *label, + Node *output, + bool soft_label, + int ignore_index, + int reduction, + int axis) { + Node *cast_and_reshape = nullptr; + Node *final_loss_node = nullptr; + if (soft_label) { + PADDLE_THROW(platform::errors::InvalidArgument( + "soft_label is not supported yet in IPU")); + } + bool append_identity_loss = is_dynamic_graph(); + bool is_last_var_node = IsLastVarNode(output); + append_identity_loss = append_identity_loss && is_last_var_node; + + if (label->Var()->GetDataType() == framework::proto::VarType::INT32) { + cast_and_reshape = label; + } else { + cast_and_reshape = + CreateCast(graph, node, {label}, {}, framework::proto::VarType::INT32) + ->outputs.front(); + } + + auto label_shape_ = label->Var()->GetShape(); + auto logits_shape_ = logits->Var()->GetShape(); + + axis = axis < 0 ? logits_shape_.size() + axis : axis; + + auto label_transposed(label_shape_); + + if (axis != (logits_shape_.size() - 1)) { + // the softmax axis(a) is not at the last dimension. + // logit shape: [N1, ..., C, ..., Nk] + // label shape: [N1, ..., 1, ..., Nk] + // _____^_____ + // dim: 0, ..., a, ..., k-1 + // needs to transpose the softmax axis in logit to last dimension + // with following transpose perm: [0, ..., a-1, a+1, ..., k-1, a] + std::vector trans(logits_shape_.size(), 0); + std::iota(trans.begin(), trans.begin() + axis, 0); + std::iota(trans.begin() + axis, trans.end() - 1, axis + 1); + trans.back() = axis; + + // transpose logits + logits = + CreateBaseOp( + graph, node, "popart_transpose", {logits}, {}, {{"perm", trans}}) + ->outputs.front(); + + // no need to transpose label, transform the label size and reshape later. + std::transform( + trans.cbegin(), + trans.cend(), + label_transposed.begin(), + [&label_shape_](int64_t index) { return label_shape_[index]; }); + } + + if (label_transposed.back() == 1) { + // input shape: [N1, N2, ... , Nk, C] + // label shape: [N1, N2, ... , Nk, 1] + // reshape label shape to [N1, N2, ... , Nk] + std::vector new_shape_(label_transposed.begin(), + label_transposed.end() - 1); + auto const_before_loss = + CreateBaseOp( + graph, + node, + "popart_constant", + {}, + {}, + {{"value", new_shape_}, + {"dims", + std::vector{static_cast(new_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}) + ->outputs.front(); + + cast_and_reshape = CreateBaseOp(graph, + node, + "popart_reshape", + {cast_and_reshape, const_before_loss}, + {}, + {}) + ->outputs.front(); + } + + auto log = CreateBaseOp(graph, node, "popart_log", {logits}, {}, {}) + ->outputs.front(); + + bool reshape_back = reduction == 2 && label_transposed.back() == 1; + + final_loss_node = CreateBaseOp(graph, + node, + "popart_nllloss_v2", + {log, cast_and_reshape}, + !(reshape_back || append_identity_loss) + ? std::vector{output} + : std::vector{}, + { + {"reduction", reduction}, + {"ignoreIndex", ignore_index}, + {"inputIsLogProbability", true}, + }) + ->outputs.front(); + + if (reshape_back) { + // reshape output to the shape of input label. + auto const_after_loss = + CreateBaseOp( + graph, + node, + "popart_constant", + {}, + {}, + {{"value", label_shape_}, + {"dims", + std::vector{static_cast(label_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}) + ->outputs.front(); + final_loss_node = + CreateBaseOp(graph, + node, + "popart_reshape", + {final_loss_node, const_after_loss}, + append_identity_loss ? std::vector{} + : std::vector{output}, + {}) + ->outputs.front(); + } + + if (append_identity_loss) { + final_loss_node = + CreateIdentityLossOp(graph, node, {final_loss_node}, {output}, 2); + } + + return final_loss_node; +} + +Node *cross_entropy2_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + int reduction = RemoveTailReduction(graph, node, "Y"); + auto logits = GetInputVarNode("X", node); + auto label = GetInputVarNode("Label", node); + auto output = GetOutputVarNode("Y", node); + auto ignore_index = BOOST_GET_CONST(int, op->GetAttr("ignore_index")); + return cross_entropy_general_handler(graph, + node, + logits, + label, + output, + false, /*soft_label*/ + ignore_index, + reduction, + -1); /*axis*/ +} + +Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + int reduction = RemoveTailReduction(graph, node, "Loss"); + auto logits = GetInputVarNode("Logits", node); + auto label = GetInputVarNode("Label", node); + auto output = GetOutputVarNode("Loss", node); + auto ignore_index = BOOST_GET_CONST(int, op->GetAttr("ignore_index")); + auto axis = BOOST_GET_CONST(int, op->GetAttr("axis")); + auto soft_label = BOOST_GET_CONST(bool, op->GetAttr("soft_label")); + + logits = CreateSoftmaxOpset11( + graph, node, {logits}, {GetOutputVarNode("Softmax", node)}, axis) + ->outputs.front(); + return cross_entropy_general_handler(graph, + node, + logits, + label, + output, + soft_label, + ignore_index, + reduction, + axis); +} + +Node *kldiv_loss_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto reduction = ConvertToPopartReduction( + BOOST_GET_CONST(std::string, op->GetAttr("reduction"))); + if (reduction == 2) { + reduction = RemoveTailReduction(graph, node, "Loss"); + } + bool append_identity_loss = is_dynamic_graph(); + bool is_last_var_node = IsLastVarNode(GetOutputVarNode("Loss", node)); + append_identity_loss = append_identity_loss && is_last_var_node; + + // log(pred) + auto log = + CreateBaseOp( + graph, node, "popart_log", {GetInputVarNode("Target", node)}, {}, {}) + ->outputs.front(); + + // log(pred) - label + auto log_minus = + CreateBaseOp( + graph, node, "popart_sub", {log, GetInputVarNode("X", node)}, {}, {}) + ->outputs.front(); + + // label * (log(pred) - label) + auto loss = + CreateBaseOp(graph, + node, + "popart_mul", + {GetInputVarNode("Target", node), log_minus}, + append_identity_loss || reduction != 2 + ? std::vector{} + : std::vector{GetOutputVarNode("Loss", node)}, + {}); + + auto attrs = AttributeMap{{"reduce_all", true}, {"keepdims", 0L}}; + if (append_identity_loss) { + loss = CreateIdentityLossOp(graph, + node, + loss->outputs, + {GetOutputVarNode("Loss", node)}, + reduction); + } else if (reduction == 0) { + // Sum + loss = CreateBaseOp(graph, + node, + "popart_reducesum", + loss->outputs, + {GetOutputVarNode("Loss", node)}, + attrs); + } else if (reduction == 1) { + // Mean + loss = CreateBaseOp(graph, + node, + "popart_reducemean", + loss->outputs, + {GetOutputVarNode("Loss", node)}, + attrs); + } + return loss; +} + +Node *binary_cross_entropy_handler(Graph *graph, Node *node) { + // Out = -1 * weight * (label * log(x) + (1 - label) * log(1 - x)) + int reduction = 2; + if (is_dynamic_graph()) { + reduction = RemoveTailReduction(graph, node, "Out"); + } + bool append_identity_loss = + is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Loss", node)); + + auto x = GetInputVarNode("X", node); + auto label = GetInputVarNode("Label", node); + // log(x) + auto log = + CreateBaseOp(graph, node, "popart_log", {x}, {}, {})->outputs.front(); + + // label * log(x) + auto log_mul = CreateBaseOp(graph, node, "popart_mul", {label, log}, {}, {}) + ->outputs.front(); + + // const one + auto one = + CreateConst(graph, node, std::vector{1.0}, {1}, GetVarDType(x)) + ->outputs.front(); + // (1 - x) + auto minus_input = CreateBaseOp(graph, node, "popart_sub", {one, x}, {}, {}) + ->outputs.front(); + + // log(1 - x) + auto log_minus_input = + CreateBaseOp(graph, node, "popart_log", {minus_input}, {}, {}) + ->outputs.front(); + + // (1 - label) + auto minus_label = + CreateBaseOp(graph, node, "popart_sub", {one, label}, {}, {}) + ->outputs.front(); + + // (1 - label) * log(1 - x) + auto minus_log_mul = + CreateBaseOp( + graph, node, "popart_mul", {minus_label, log_minus_input}, {}, {}) + ->outputs.front(); + + // (label * log(x) + (1 - label) * log(1 - x)) + auto add = + CreateBaseOp(graph, node, "popart_add", {log_mul, minus_log_mul}, {}, {}) + ->outputs.front(); + + // -1 * (label * log(x) + (1 - label) * log(1 - x)) + auto loss = CreateBaseOp( + graph, + node, + "popart_neg", + {add}, + append_identity_loss ? std::vector{} + : std::vector{GetOutputVarNode("Out", node)}, + {}); + if (append_identity_loss) { + loss = CreateIdentityLossOp( + graph, node, loss->outputs, {GetOutputVarNode("Out", node)}, reduction); + } + return loss; +} + +Node *huber_loss_handler(Graph *graph, Node *node) { + // if abs(label - input) < delta + // huber_loss = 0.5 * (label - input) * (label - input) + // else + // huber_loss = delta * abs(label - input) - 0.5 * delta * delta + auto *op = node->Op(); + int reduction = 2; + if (is_dynamic_graph()) { + reduction = RemoveTailReduction(graph, node, "Out"); + } + bool append_identity_loss = + is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Out", node)); + + auto x = GetInputVarNode("X", node); + auto label = GetInputVarNode("Y", node); + // (label - input) + auto diff = CreateBaseOp(graph, node, "popart_sub", {label, x}, {}, {}) + ->outputs.front(); + + // abs(label - input) + auto abs_diff = + CreateBaseOp(graph, node, "popart_abs", {diff}, {}, {})->outputs.front(); + + // const 0.5 + auto dot_five = + CreateConst(graph, node, std::vector{0.5}, {1}, GetVarDType(x)) + ->outputs.front(); + + // const delta + auto delta_value = BOOST_GET_CONST(float, op->GetAttr("delta")); + auto delta = + CreateConst( + graph, node, std::vector{delta_value}, {1}, GetVarDType(x)) + ->outputs.front(); + auto delta_square_coff = + CreateConst(graph, + node, + std::vector{0.5f * delta_value * delta_value}, + {1}, + GetVarDType(x)) + ->outputs.front(); + + // (label - input) * (label - input) + auto square = CreateBaseOp(graph, node, "popart_mul", {diff, diff}, {}, {}) + ->outputs.front(); + + // 0.5 * (label - input) * (label - input) + auto dot_five_square = + CreateBaseOp(graph, node, "popart_mul", {dot_five, square}, {}, {}) + ->outputs.front(); + + // delta * abs(label - input) + auto delta_mul_diff = + CreateBaseOp(graph, node, "popart_mul", {delta, abs_diff}, {}, {}) + ->outputs.front(); + + // delta * abs(label - input) - 0.5 * delta * delta + auto sub_delta_square = CreateBaseOp(graph, + node, + "popart_sub", + {delta_mul_diff, delta_square_coff}, + {}, + {}) + ->outputs.front(); + + // abs(label - input) < delta + auto less_cond = + CreateBaseOp(graph, node, "popart_less", {abs_diff, delta}, {}, {}) + ->outputs.front(); + auto loss = CreateBaseOp( + graph, + node, + "popart_where", + {less_cond, dot_five_square, sub_delta_square}, + append_identity_loss ? std::vector{} + : std::vector{GetOutputVarNode("Out", node)}, + {}); + + if (append_identity_loss) { + loss = CreateIdentityLossOp( + graph, node, loss->outputs, {GetOutputVarNode("Out", node)}, reduction); + } + return loss; +} + +Node *warpctc_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto logits = GetInputVarNode("Logits", node); + auto label = GetInputVarNode("Label", node); + auto logits_length = GetInputVarNode("LogitsLength", node); + auto label_length = GetInputVarNode("LabelLength", node); + auto blank = BOOST_GET_CONST(int, op->GetAttr("blank")); + auto norm_by_times = BOOST_GET_CONST(bool, op->GetAttr("norm_by_times")); + int reduction = 2; + if (is_dynamic_graph()) { + reduction = RemoveTailReduction(graph, node, "Loss"); + } + bool append_identity_loss = + is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Loss", node)); + if (norm_by_times) { + PADDLE_THROW(platform::errors::InvalidArgument( + "norm_by_times is not supported yet in IPU")); + } + + int axis = -1; + auto softmax_logits = + CreateSoftmaxOpset11(graph, node, {logits}, {}, axis)->outputs.front(); + auto log_softmax_logits = + CreateBaseOp(graph, node, "popart_log", {softmax_logits}, {}, {}) + ->outputs.front(); + auto cast_label = CreateBaseOp(graph, + node, + "popart_cast", + {label}, + {}, + {{"to", std::string("UINT32")}}) + ->outputs.front(); + auto cast_logits_length = CreateBaseOp(graph, + node, + "popart_cast", + {logits_length}, + {}, + {{"to", std::string("UINT32")}}) + ->outputs.front(); + auto cast_label_length = CreateBaseOp(graph, + node, + "popart_cast", + {label_length}, + {}, + {{"to", std::string("UINT32")}}) + ->outputs.front(); + // TODO(czr): zero_infinity is not supported in current sdk which lead + // difference with paddle result. + auto loss = CreateBaseOp( + graph, + node, + "popart_ctcloss", + {log_softmax_logits, cast_label, cast_logits_length, cast_label_length}, + append_identity_loss + ? std::vector{} + : std::vector{GetOutputVarNode("Loss", node)}, + {{"blank", blank}, + {"reduction", reduction}, + {"outDataType", std::string("UNDEFINED")}}); + if (append_identity_loss) { + loss = CreateIdentityLossOp( + graph, node, loss->outputs, {GetOutputVarNode("Loss", node)}, 2); + } + return loss; +} + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle + +REGISTER_HANDLER(identity_loss, identity_loss_handler); +REGISTER_HANDLER(softmax_with_cross_entropy, + softmax_with_cross_entropy_handler); +REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler); +REGISTER_HANDLER(kldiv_loss, kldiv_loss_handler); +REGISTER_HANDLER(bce_loss, binary_cross_entropy_handler); +REGISTER_HANDLER(huber_loss, huber_loss_handler); +REGISTER_HANDLER(warpctc, warpctc_handler); diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc index e47a723125b..ddd7d9453cf 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc @@ -114,14 +114,29 @@ Node *matmul_handler(Graph *graph, Node *node) { auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X")); auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y")); auto alpha = BOOST_GET_CONST(float, op->GetAttr("alpha")); - auto x_shape = GetInputVarNode("X", node)->Var()->GetShape(); - auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape(); + Node *x_node = GetInputVarNode("X", node); + Node *y_node = GetInputVarNode("Y", node); + int x_rank = x_node->Var()->GetShape().size(); + int y_rank = y_node->Var()->GetShape().size(); + + auto gen_perm = [](const int rank) -> std::vector { + std::vector perm; + if (rank == 1) { + perm = std::vector{0}; + } else if (rank == 2) { + perm = std::vector{1, 0}; + } else if (rank == 3) { + perm = std::vector{0, 2, 1}; + } else if (rank == 4) { + perm = std::vector{0, 1, 3, 2}; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op matmul with input rank == %d", rank)); + } + return perm; + }; - int x_rank = x_shape.size(); - std::vector perm; - if (x_rank == 1) { - perm = std::vector{0}; - } else if (x_rank == 2) { + if (x_rank == 2) { if (!transpose_x && !transpose_y && is_float_equal(alpha, 1.0f)) { return CreateBaseOp( graph, @@ -137,18 +152,10 @@ Node *matmul_handler(Graph *graph, Node *node) { transpose_x, transpose_y, alpha); - } else if (x_rank == 3) { - perm = std::vector{0, 2, 1}; - } else if (x_rank == 4) { - perm = std::vector{0, 1, 3, 2}; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "op matmul with input rank == %d", x_rank)); } - Node *x_node = GetInputVarNode("X", node); - Node *y_node = GetInputVarNode("Y", node); if (transpose_x) { + auto perm = gen_perm(x_rank); x_node = CreateBaseOp(graph, node, "popart_transpose", @@ -158,6 +165,7 @@ Node *matmul_handler(Graph *graph, Node *node) { x_node = x_node->outputs[0]; } if (transpose_y) { + auto perm = gen_perm(y_rank); y_node = CreateBaseOp(graph, node, "popart_transpose", @@ -209,7 +217,7 @@ Node *scale_handler(Graph *graph, Node *node) { CreateCast(graph, node, {GetInputVarNode("X", node)}, {}, VarType::FP32); Node *result = nullptr; - if (!op->Input("ScaleTensor").empty()) { + if (op->InputArgumentNames().size() > 1) { auto scale = GetInputVarNode("ScaleTensor", node); if (is_float_equal(bias_, 0.0)) { result = CreateBaseOp( @@ -321,183 +329,6 @@ Node *scale_handler(Graph *graph, Node *node) { return result_after_cast; } -Node *cross_entropy2_handler(Graph *graph, Node *node) { - auto *op = node->Op(); - auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index")); - Node *new_cast = nullptr; - if (GetInputVarNode("Label", node)->Var()->GetDataType() == VarType::INT32) { - new_cast = GetInputVarNode("Label", node); - } else { - auto new_cast = CreateCast( - graph, node, {GetInputVarNode("Label", node)}, {}, VarType::INT32); - new_cast = new_cast->outputs[0]; - } - auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape(); - if (label_shape_[label_shape_.size() - 1] != 1) { - auto log = CreateBaseOp( - graph, node, "popart_log", {GetInputVarNode("X", node)}, {}, {}); - return CreateBaseOp( - graph, - node, - "popart_nllloss_v2", - {log->outputs[0], new_cast}, - {GetOutputVarNode("Y", node)}, - { - {"reduction", 2}, // popart::ReductionType::NoReduction - {"ignoreIndex", ignoreIndex}, - {"inputIsLogProbability", true}, - }); - } else { - std::vector new_shape_{label_shape_[0]}; - auto const_before_loss = CreateBaseOp( - graph, - node, - "popart_constant", - {}, - {}, - {{"value", new_shape_}, - {"dims", - std::vector{static_cast(new_shape_.size())}}, - {"dtype", ONNXDataType::INT64}}); - - auto reshape_before_loss = - CreateBaseOp(graph, - node, - "popart_reshape", - {new_cast, const_before_loss->outputs[0]}, - {}, - {}); - - auto log = CreateBaseOp( - graph, node, "popart_log", {GetInputVarNode("X", node)}, {}, {}); - auto nllloss = CreateBaseOp( - graph, - node, - "popart_nllloss_v2", - {log->outputs[0], reshape_before_loss->outputs[0]}, - {}, - { - {"reduction", 2}, // popart::ReductionType::NoReduction - {"ignoreIndex", ignoreIndex}, - {"inputIsLogProbability", true}, - }); - - auto const_after_loss = CreateBaseOp( - graph, - node, - "popart_constant", - {}, - {}, - {{"value", label_shape_}, - {"dims", - std::vector{static_cast(label_shape_.size())}}, - {"dtype", ONNXDataType::INT64}}); - - auto reshape_after_loss = - CreateBaseOp(graph, - node, - "popart_reshape", - {nllloss->outputs[0], const_after_loss->outputs[0]}, - {GetOutputVarNode("Y", node)}, - {}); - return reshape_after_loss; - } -} - -Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) { - auto *op = node->Op(); - auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index")); - auto axis = BOOST_GET_CONST(int, op->GetAttr("axis")); - auto soft_label = BOOST_GET_CONST(bool, op->GetAttr("soft_label")); - if (soft_label) { - PADDLE_THROW(platform::errors::InvalidArgument( - "soft_label is not supported yet in IPU")); - } - Node *new_cast = nullptr; - if (GetInputVarNode("Label", node)->Var()->GetDataType() == VarType::INT32) { - new_cast = GetInputVarNode("Label", node); - } else { - auto new_cast = CreateCast( - graph, node, {GetInputVarNode("Label", node)}, {}, VarType::INT32); - new_cast = new_cast->outputs[0]; - } - auto softmax_node = CreateSoftmaxOpset11( - graph, node, {GetInputVarNode("Logits", node)}, {}, axis); - - auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape(); - if (label_shape_[label_shape_.size() - 1] != 1) { - auto log = CreateBaseOp( - graph, node, "popart_log", {softmax_node->outputs[0]}, {}, {}); - // softmax_with_cross_entropy is split to several ops in python. - // reduction is not needed here. - return CreateBaseOp( - graph, - node, - "popart_nllloss_v2", - {log->outputs[0], new_cast}, - {GetOutputVarNode("Loss", node)}, - { - {"reduction", 2}, // popart::ReductionType::NoReduction - {"ignoreIndex", ignoreIndex}, - {"inputIsLogProbability", true}, - }); - } else { - std::vector new_shape_{label_shape_[0]}; - auto const_before_loss = CreateBaseOp( - graph, - node, - "popart_constant", - {}, - {}, - {{"value", new_shape_}, - {"dims", - std::vector{static_cast(new_shape_.size())}}, - {"dtype", ONNXDataType::INT64}}); - - auto reshape_before_loss = - CreateBaseOp(graph, - node, - "popart_reshape", - {new_cast, const_before_loss->outputs[0]}, - {}, - {}); - - auto log = CreateBaseOp( - graph, node, "popart_log", {softmax_node->outputs[0]}, {}, {}); - auto nllloss = CreateBaseOp( - graph, - node, - "popart_nllloss_v2", - {log->outputs[0], reshape_before_loss->outputs[0]}, - {}, - { - {"reduction", 2}, // popart::ReductionType::NoReduction - {"ignoreIndex", ignoreIndex}, - {"inputIsLogProbability", true}, - }); - - auto const_after_loss = CreateBaseOp( - graph, - node, - "popart_constant", - {}, - {}, - {{"value", label_shape_}, - {"dims", - std::vector{static_cast(label_shape_.size())}}, - {"dtype", ONNXDataType::INT64}}); - - auto reshape_after_loss = - CreateBaseOp(graph, - node, - "popart_reshape", - {nllloss->outputs[0], const_after_loss->outputs[0]}, - {GetOutputVarNode("Loss", node)}, - {}); - return reshape_after_loss; - } -} - Node *cumsum_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive")); @@ -512,41 +343,63 @@ Node *cumsum_handler(Graph *graph, Node *node) { {{"value", std::vector{axis}}, {"dims", std::vector{1}}, {"dtype", ONNXDataType::INT64}}); - return CreateBaseOp( + Node *input_x = nullptr; + auto data_type_ = GetInputVarNode("X", node)->Var()->GetDataType(); + bool need_cast = data_type_ != VarType::FP32; + std::vector cumsum_out; + if (need_cast) { + auto cast_x = CreateCast( + graph, node, {GetInputVarNode("X", node)}, {}, VarType::FP32); + input_x = cast_x->outputs[0]; + } else { + input_x = GetInputVarNode("X", node); + cumsum_out.emplace_back(GetOutputVarNode("Out", node)); + } + auto cumsum_node = CreateBaseOp( graph, node, "popart_cumsum", - {GetInputVarNode("X", node), axis_node->outputs[0]}, - {GetOutputVarNode("Out", node)}, + {input_x, axis_node->outputs[0]}, + cumsum_out, {{"exclusive", popart_exclusive}, {"reverse", popart_reverse}}); + if (need_cast) { + cumsum_node = CreateCast(graph, + node, + cumsum_node->outputs, + {GetOutputVarNode("Out", node)}, + data_type_); + } + return cumsum_node; } Node *matmul_v2_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("trans_x")); auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("trans_y")); - auto x_shape = GetInputVarNode("X", node)->Var()->GetShape(); - auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape(); - - std::vector perm; - int x_rank = x_shape.size(); - if (x_rank == 1) { - perm = std::vector{0}; - } else if (x_rank == 2) { - perm = std::vector{1, 0}; - } else if (x_rank == 3) { - perm = std::vector{0, 2, 1}; - } else if (x_rank == 4) { - perm = std::vector{0, 1, 3, 2}; - } else { - PADDLE_THROW(platform::errors::InvalidArgument( - "op matmul with input rank == %d", x_rank)); - } - Node *x_node = GetInputVarNode("X", node); Node *y_node = GetInputVarNode("Y", node); + int x_rank = x_node->Var()->GetShape().size(); + int y_rank = y_node->Var()->GetShape().size(); + + auto gen_perm = [](const int rank) -> std::vector { + std::vector perm; + if (rank == 1) { + perm = std::vector{0}; + } else if (rank == 2) { + perm = std::vector{1, 0}; + } else if (rank == 3) { + perm = std::vector{0, 2, 1}; + } else if (rank == 4) { + perm = std::vector{0, 1, 3, 2}; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op matmul with input rank == %d", rank)); + } + return perm; + }; if (transpose_x) { + auto perm = gen_perm(x_rank); x_node = CreateBaseOp(graph, node, "popart_transpose", @@ -556,6 +409,7 @@ Node *matmul_v2_handler(Graph *graph, Node *node) { x_node = x_node->outputs[0]; } if (transpose_y) { + auto perm = gen_perm(y_rank); y_node = CreateBaseOp(graph, node, "popart_transpose", @@ -611,9 +465,6 @@ REGISTER_HANDLER(matmul, matmul_handler); REGISTER_HANDLER(sum, sum_handler); REGISTER_HANDLER(softmax, softmax_handler); REGISTER_HANDLER(scale, scale_handler); -REGISTER_HANDLER(softmax_with_cross_entropy, - softmax_with_cross_entropy_handler); -REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler); REGISTER_HANDLER(cumsum, cumsum_handler); REGISTER_HANDLER(matmul_v2, matmul_v2_handler); REGISTER_HANDLER(bmm, bmm_handler); diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc index 173ea6d4d51..6badf37d5b3 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc @@ -55,9 +55,20 @@ Node *MakeOpNode(Graph *graph, op_desc->SetType(type); auto op = graph->CreateOpNode(op_desc.get()); + // inputs + std::vector input_names; for (auto *in : inputs) { - ConnectNodes(in, op); + if (in != nullptr) { + ConnectNodes(in, op); + input_names.push_back(in->Name()); + } else { + input_names.push_back(std::string("")); + } } + op->Op()->SetInput("__inputs__", input_names); + + // outputs + std::vector output_names; if (outputs.empty()) { auto var = MakeVarNode(graph, node); ConnectNodes(op, var); @@ -66,14 +77,6 @@ Node *MakeOpNode(Graph *graph, ConnectNodes(op, out); } } - - // i/o - std::vector input_names; - for (auto node : op->inputs) { - input_names.push_back(node->Name()); - } - op->Op()->SetInput("__inputs__", input_names); - std::vector output_names; for (auto node : op->outputs) { output_names.push_back(node->Name()); } @@ -138,6 +141,19 @@ Node *CreateCast(Graph *graph, graph, node, "popart_cast", inputs, outputs, {{"to", to}}); } +Node *CreateIdentityLossOp(Graph *graph, + Node *node, + const std::vector &inputs, + const std::vector &outputs, + int reduction) { + return CreateBaseOp(graph, + node, + "popart_identity_loss", + inputs, + outputs, + {{"reduction", reduction}}); +} + Node *CreateGemm(Graph *graph, Node *node, const std::vector &inputs, diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h index 582b506974f..3071c2a0b90 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h @@ -67,6 +67,12 @@ Node *CreateCast(Graph *graph, const std::vector &outputs, const VarType::Type otype); +Node *CreateIdentityLossOp(Graph *graph, + Node *node, + const std::vector &inputs, + const std::vector &outputs, + int reduction); + Node *CreateGemm(Graph *graph, Node *node, const std::vector &inputs, diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc index 1e9291cf572..0b95f641695 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc @@ -85,17 +85,6 @@ Node *identity_handler(Graph *graph, Node *node) { graph, node, "popart_identity", node->inputs, node->outputs); } -Node *identity_loss_handler(Graph *graph, Node *node) { - auto *op = node->Op(); - auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction")); - return CreateBaseOp(graph, - node, - "popart_identity_loss", - node->inputs, - node->outputs, - {{"reduction", reduction}}); -} - Node *detach_handler(Graph *graph, Node *node) { return CreateBaseOp( graph, node, "popart_detach_v2", node->inputs, node->outputs); @@ -112,5 +101,4 @@ REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler); REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler); REGISTER_HANDLER(custom_nll_loss, custom_nll_loss_handler); REGISTER_HANDLER(identity, identity_handler); -REGISTER_HANDLER(identity_loss, identity_loss_handler); REGISTER_HANDLER(detach, detach_handler); diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/reduce_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/reduce_ops.cc index 852cb180aa7..e1cc2de8bc5 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/reduce_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/reduce_ops.cc @@ -36,6 +36,27 @@ Node *reduce_op_handler(Graph *graph, Node *node, const std::string &op_name) { return CreateBaseOp(graph, node, op_name, node->inputs, node->outputs, attrs); } +Node *reduce_all_op_handler(Graph *graph, + Node *node, + const std::string &op_name) { + auto *op = node->Op(); + auto attrs = AttributeMap{}; + auto reduce_all = BOOST_GET_CONST(bool, op->GetAttr("reduce_all")); + if (!reduce_all) { + auto axes_ = BOOST_GET_CONST(std::vector, op->GetAttr("dim")); + auto axes = std::vector{axes_.begin(), axes_.end()}; + attrs.emplace("axes", axes); + } + auto keepdims_ = BOOST_GET_CONST(bool, op->GetAttr("keep_dim")); + auto keepdims = int64_t{keepdims_}; + attrs.emplace("keepdims", keepdims); + auto int32_x = + CreateCast(graph, node, node->inputs, {}, VarType::INT32)->outputs[0]; + auto reduce_op = CreateBaseOp(graph, node, op_name, {int32_x}, {}, attrs); + return CreateCast( + graph, node, reduce_op->outputs, node->outputs, VarType::BOOL); +} + Node *reduce_mean_handler(Graph *graph, Node *node) { return reduce_op_handler(graph, node, "popart_reducemean"); } @@ -56,6 +77,34 @@ Node *reduce_prod_handler(Graph *graph, Node *node) { return reduce_op_handler(graph, node, "popart_reduceprod"); } +Node *logsumexp_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto attrs = AttributeMap{}; + auto reduce_all = BOOST_GET_CONST(bool, op->GetAttr("reduce_all")); + if (!reduce_all) { + auto axes_ = BOOST_GET_CONST(std::vector, op->GetAttr("axis")); + auto axes = std::vector{axes_.begin(), axes_.end()}; + attrs.emplace("axes", axes); + } + auto keepdims_ = BOOST_GET_CONST(bool, op->GetAttr("keepdim")); + auto keepdims = int64_t{keepdims_}; + attrs.emplace("keepdims", keepdims); + return CreateBaseOp(graph, + node, + "popart_reducelogsumexp", + node->inputs, + node->outputs, + attrs); +} + +Node *reduce_all_handler(Graph *graph, Node *node) { + return reduce_all_op_handler(graph, node, "popart_reducemin"); +} + +Node *reduce_any_handler(Graph *graph, Node *node) { + return reduce_all_op_handler(graph, node, "popart_reducemax"); +} + } // namespace } // namespace ipu } // namespace platform @@ -66,3 +115,6 @@ REGISTER_HANDLER(reduce_min, reduce_min_handler); REGISTER_HANDLER(reduce_sum, reduce_sum_handler); REGISTER_HANDLER(reduce_max, reduce_max_handler); REGISTER_HANDLER(reduce_prod, reduce_prod_handler); +REGISTER_HANDLER(logsumexp, logsumexp_handler); +REGISTER_HANDLER(reduce_all, reduce_all_handler); +REGISTER_HANDLER(reduce_any, reduce_any_handler); diff --git a/paddle/fluid/platform/device/ipu/supported_ops_autogen.h b/paddle/fluid/platform/device/ipu/supported_ops_autogen.h index 763c5a46abe..14dcf65afee 100644 --- a/paddle/fluid/platform/device/ipu/supported_ops_autogen.h +++ b/paddle/fluid/platform/device/ipu/supported_ops_autogen.h @@ -33,6 +33,7 @@ OP_DECL(popart_dynamicadd_v2, aiGraphcoreOpset.dynamicadd, ARG(INT_VEC,axes) ARG OP_DECL(popart_sequenceslice_v2, aiGraphcoreOpset.sequenceslice, ARG(INT,zeroUnused) ) // NOLINT OP_DECL(popart_replicatedallreduce_v2, aiGraphcoreOpset.replicatedallreduce, OPT_ARG(INT_VEC,commGroup) ) // NOLINT OP_DECL(popart_ctcbeamsearchdecoder_v2, aiGraphcoreOpset.ctcbeamsearchdecoder, ARG(INT,blank) ARG(INT,beamWidth) ARG(INT,topPaths) ) // NOLINT +OP_DECL(popart_ctcloss, aiGraphcoreOpset.ctcloss, SIG_ARG(INT32,popart::ReductionType,reduction) ARG(INT32,blank) ARG(STRING,outDataType) ) // NOLINT OP_DECL(popart_shapeddropout_v2, aiGraphcoreOpset.shapeddropout, ARG(INT_VEC,shape) ARG(FLOAT,ratio) ) // NOLINT OP_DECL(popart_atan2_v2, aiGraphcoreOpset.atan2, NONE) // NOLINT OP_DECL(popart_expm1_v2, aiGraphcoreOpset.expm1, NONE) // NOLINT -- GitLab