Unverified commit 0a04b8a9, authored by Allen Guo, committed by GitHub

[IPU] support more ops 0/N (#44204)

* add authors
Co-authored-by: Allen Guo <alleng@graphcore.ai>
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>

* squash cpp changes 1/N

* clean code
Co-authored-by: Zhixin Yao <zhixiny@graphcore.ai>
Co-authored-by: Zhaorui Chen <zhaoruic@graphcore.ai>
Parent dd63e5b4
...@@ -287,19 +287,6 @@ void IpuOptimizerExtractPass::ApplyImpl(ir::Graph* graph) const {
} else if (op_role == OpRole::kLRSched) {
// op_role == OpRole::kLRSched | OpRole::kOptimize
new_op.SetAttr("with_lr_sched", true);
} else if (op_type == "identity_loss") {
auto outputs = op->Outputs();
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
platform::errors::InvalidArgument("Can only support one loss key"));
auto losses = outputs.begin()->second;
PADDLE_ENFORCE_EQ(
losses.size(),
1,
platform::errors::InvalidArgument("Can only support one loss name"));
auto loss_var = losses.front();
new_op.SetAttr("loss_var", loss_var);
}
}
......
...@@ -30,8 +30,13 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
auto custom_ops = Get<std::unordered_set<std::string>>("custom_ops");
std::vector<std::string> missing_ops;
auto sorted_ops = TopologySortOperations(*graph);
std::unordered_set<Node*> delete_nodes;
for (auto* node : sorted_ops) {
auto* op = node->Op();
if (platform::ipu::IsMarkedForDeletion(node)) {
delete_nodes.insert(node);
continue;
}
auto op_type = op->Type();
ir::Node* new_node = nullptr;
...@@ -67,6 +72,12 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
"Found unimplemented op_handler(s) for IPU"));
}
for (auto* node : delete_nodes) {
// TODO(czr): possible remove
platform::ipu::ClearNode(node);
graph->RemoveNode(node);
}
// post popart_canonicalization
VLOG(10) << "Post Graph: ";
......
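The pass above deliberately defers deletion: nodes flagged by MarkNodeForDeletion are only collected into delete_nodes during the traversal and erased after all handlers have run, so the topological iteration is never invalidated. A minimal standalone sketch of the same defer-then-erase pattern (hypothetical FakeNode/RunPassSketch names, not the Paddle API):

#include <algorithm>
#include <string>
#include <unordered_set>
#include <vector>

struct FakeNode {
  std::string name;
  bool marked = false;  // stands in for IsMarkedForDeletion()
};

// Erasing while iterating would invalidate the traversal, so first record the
// doomed nodes and remove them in a second phase.
void RunPassSketch(std::vector<FakeNode*>* graph) {
  std::unordered_set<FakeNode*> delete_nodes;
  for (auto* node : *graph) {
    if (node->marked) {
      delete_nodes.insert(node);
      continue;  // skip handler dispatch for nodes that will be removed
    }
    // ... handler dispatch would happen here ...
  }
  // second phase: actually drop the collected nodes
  graph->erase(std::remove_if(graph->begin(), graph->end(),
                              [&](FakeNode* n) { return delete_nodes.count(n) > 0; }),
               graph->end());
}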
...@@ -445,6 +445,7 @@ void Compiler::LowerWeights(const Scope* scope) {
for (size_t i = 0; i < tensor.dims().size(); ++i) {
shape.push_back(tensor.dims().at(i));
}
popart::TensorInfo tensor_info(dtype, shape);
popart::ConstVoidData const_data{tensor.data(), tensor_info};
if (!node->outputs.empty()) {
...@@ -530,15 +531,19 @@ void Compiler::LowerOptimizer(const Scope* scope) {
auto raw_type =
BOOST_GET_CONST(std::string, op_desc->GetAttr("raw_type"));
resources_->optimizer_type = raw_type;
auto loss_var =
BOOST_GET_CONST(std::string, op_desc->GetAttr("loss_var"));
resources_->loss_var = resources_->tensors[loss_var];
resources_->with_lr_sched =
BOOST_GET_CONST(bool, op_desc->GetAttr("with_lr_sched"));
if (ipu_strategy_->is_dynamic) {
// loss_var in dy2static is set by identity_loss. And lr is
// passed by ipu_strategy.
resources_->lr = ipu_strategy_->lr;
} else {
auto loss_var =
BOOST_GET_CONST(std::string, op_desc->GetAttr("loss_var"));
resources_->loss_var = resources_->tensors[loss_var];
if (op_desc->HasAttr("lr_var")) {
auto lr_var =
BOOST_GET_CONST(std::string, op_desc->GetAttr("lr_var"));
resources_->lr_var = lr_var;
resources_->lr = GetSingleVarFromScope<float>(scope, lr_var);
} else {
...@@ -546,6 +551,7 @@ void Compiler::LowerOptimizer(const Scope* scope) {
resources_->lr = 0.01f;
resources_->with_lr_sched = false;
}
}
VLOG(10) << "Set initial lr: " << resources_->lr;
// Get the type of optimizer
...@@ -766,6 +772,19 @@ void Compiler::LowerOptimizer(const Scope* scope) {
PADDLE_THROW(platform::errors::Unimplemented(
"optimizer %s is not implemented", type));
}
} else if (op_type == "popart_identity_loss") {
auto outputs = op_desc->Outputs();
PADDLE_ENFORCE_EQ(
outputs.size(),
1,
platform::errors::InvalidArgument("Can only support one loss key"));
auto losses = outputs.begin()->second;
PADDLE_ENFORCE_EQ(
losses.size(),
1,
platform::errors::InvalidArgument("Can only support one loss name"));
auto loss_var = losses.front();
resources_->loss_var = resources_->tensors[loss_var];
}
}
}
......
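After this change the optimizer lowering branches as follows: in dynamic-graph (dy2static) mode the loss variable is filled in later by the identity_loss handler and the learning rate comes from ipu_strategy, otherwise loss_var and, when present, lr_var are read from the op attributes. A simplified sketch of that control flow (illustrative names, not the real Compiler members):

#include <string>

struct OptimizerResources {  // simplified stand-in for the Compiler's resources_
  std::string loss_var;
  std::string lr_var;
  float lr = 0.0f;
  bool with_lr_sched = true;
};

// Sketch of the branch introduced above; the attribute values are placeholders.
void LowerOptimizerSketch(OptimizerResources* res, bool is_dynamic,
                          bool has_lr_var, float strategy_lr) {
  if (is_dynamic) {
    // dy2static: loss_var is set by the identity_loss handler,
    // lr is passed in through ipu_strategy.
    res->lr = strategy_lr;
  } else {
    res->loss_var = "loss_var_from_attr";  // read from the op attribute
    if (has_lr_var) {
      res->lr_var = "lr_var_from_attr";    // scheduled lr lives in a variable
    } else {
      res->lr = 0.01f;                     // constant lr fallback
      res->with_lr_sched = false;
    }
  }
}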
...@@ -138,6 +138,59 @@ const ONNXDataType GetOutputVarDType(const Node *node,
return GetVarDType(out_node);
}
bool IsLastVarNode(Node *node) {
return node->IsVar() && node->outputs.size() == 0;
}
void MarkNodeForDeletion(Node *node) { node->Op()->SetAttr("delete_node", 1); }
bool IsMarkedForDeletion(Node *node) {
return node->Op()->HasAttr("delete_node") &&
BOOST_GET_CONST(int, node->Op()->GetAttr("delete_node")) > 0;
}
int RemoveTailReduction(Graph *graph,
Node *loss_op,
const std::string &output_var_name) {
// Sum: 0. Mean: 1. None: 2
int reduction = 2;
Node *reduction_op;
auto loss_output = GetOutputVarNode(output_var_name, loss_op);
for (auto sub_node : loss_output->outputs) {
if (!sub_node->IsOp()) continue;
if (sub_node->Op()->Type() == "reduce_sum") {
reduction = 0;
reduction_op = sub_node;
} else if (sub_node->Op()->Type() == "reduce_mean") {
reduction = 1;
reduction_op = sub_node;
}
}
if (reduction == 2) return reduction;
auto reduction_out = reduction_op->outputs[0];
loss_op->Op()->SetOutput(output_var_name,
std::vector<std::string>({reduction_out->Name()}));
MarkNodeForDeletion(reduction_op);
DisConnectNodes(loss_output, reduction_op);
DisConnectNodes(reduction_op, reduction_out);
ConnectNodes(loss_op, reduction_out);
return reduction;
}
int ConvertToPopartReduction(const std::string &reduction) {
// Sum: 0. Mean: 1. None: 2
if (reduction == "sum") {
return 0;
} else if (reduction == "mean") {
return 1;
} else if (reduction == "none") {
return 2;
}
PADDLE_THROW(platform::errors::InvalidArgument(
"reduction %s is not supported on ipu.", reduction));
}
} // namespace ipu
} // namespace platform
} // namespace paddle
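ConvertToPopartReduction maps the Paddle attribute strings onto the integer codes used throughout these handlers (sum -> 0, mean -> 1, none -> 2), and RemoveTailReduction returns the same codes when it fuses a trailing reduce_sum/reduce_mean into the loss op. A tiny standalone check of that mapping (assumed helper name, independent of the graph code):

#include <cassert>
#include <stdexcept>
#include <string>

// Mirrors the mapping above: "sum" -> 0, "mean" -> 1, "none" -> 2.
int ToPopartReductionSketch(const std::string& reduction) {
  if (reduction == "sum") return 0;
  if (reduction == "mean") return 1;
  if (reduction == "none") return 2;
  throw std::invalid_argument("reduction " + reduction + " is not supported on ipu.");
}

int main() {
  assert(ToPopartReductionSketch("sum") == 0);
  assert(ToPopartReductionSketch("mean") == 1);
  assert(ToPopartReductionSketch("none") == 2);
  return 0;
}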
...@@ -85,6 +85,13 @@ const bool is_float_equal(float a, float b, float eps = 1e-8);
const ONNXDataType GetVarDType(const Node *node);
const ONNXDataType GetOutputVarDType(const Node *node,
const std::string &output_name = "Out");
void MarkNodeForDeletion(Node *node);
bool IsMarkedForDeletion(Node *node);
bool IsLastVarNode(Node *node);
int RemoveTailReduction(Graph *graph,
Node *loss_op,
const std::string &output_var_name);
int ConvertToPopartReduction(const std::string &reduction);
} // namespace ipu
} // namespace platform
......
...@@ -91,6 +91,38 @@ Node *less_than_handler(Graph *graph, Node *node) {
{});
}
Node *greater_equal_handler(Graph *graph, Node *node) {
auto less_op =
CreateBaseOp(graph,
node,
"popart_less",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{},
{});
return CreateBaseOp(graph,
node,
"popart_logical_not",
less_op->outputs,
{GetOutputVarNode("Out", node)},
{});
}
Node *less_equal_handler(Graph *graph, Node *node) {
auto less_op =
CreateBaseOp(graph,
node,
"popart_greater",
{GetInputVarNode("X", node), GetInputVarNode("Y", node)},
{},
{});
return CreateBaseOp(graph,
node,
"popart_logical_not",
less_op->outputs,
{GetOutputVarNode("Out", node)},
{});
}
} // namespace
} // namespace ipu
} // namespace platform
...@@ -103,3 +135,5 @@ REGISTER_HANDLER(logical_or, logical_or_handler);
REGISTER_HANDLER(logical_and, logical_and_handler);
REGISTER_HANDLER(greater_than, greater_than_handler);
REGISTER_HANDLER(less_than, less_than_handler);
REGISTER_HANDLER(greater_equal, greater_equal_handler);
REGISTER_HANDLER(less_equal, less_equal_handler);
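greater_equal and less_equal are composed from the existing primitives: X >= Y is lowered as logical_not(less(X, Y)) and X <= Y as logical_not(greater(X, Y)). A scalar sanity check of those identities (plain C++, unrelated to the graph builders):

#include <cassert>

int main() {
  for (int x = -2; x <= 2; ++x) {
    for (int y = -2; y <= 2; ++y) {
      assert((x >= y) == !(x < y));  // greater_equal == logical_not(less)
      assert((x <= y) == !(x > y));  // less_equal == logical_not(greater)
    }
  }
  return 0;
}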
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/platform/device/ipu/ipu_backend.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace platform {
namespace ipu {
namespace {
bool is_dynamic_graph() {
auto *ipu_backend = platform::ipu::IpuBackend::GetInstance();
return ipu_backend->GetIpuStrategy()->is_dynamic;
}
Node *identity_loss_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction"));
return CreateIdentityLossOp(
graph, node, node->inputs, node->outputs, reduction);
}
Node *cross_entropy_general_handler(Graph *graph,
Node *node,
Node *logits,
Node *label,
Node *output,
bool soft_label,
int ignore_index,
int reduction,
int axis) {
Node *cast_and_reshape = nullptr;
Node *final_loss_node = nullptr;
if (soft_label) {
PADDLE_THROW(platform::errors::InvalidArgument(
"soft_label is not supported yet in IPU"));
}
bool append_identity_loss = is_dynamic_graph();
bool is_last_var_node = IsLastVarNode(output);
append_identity_loss = append_identity_loss && is_last_var_node;
if (label->Var()->GetDataType() == framework::proto::VarType::INT32) {
cast_and_reshape = label;
} else {
cast_and_reshape =
CreateCast(graph, node, {label}, {}, framework::proto::VarType::INT32)
->outputs.front();
}
auto label_shape_ = label->Var()->GetShape();
auto logits_shape_ = logits->Var()->GetShape();
axis = axis < 0 ? logits_shape_.size() + axis : axis;
auto label_transposed(label_shape_);
if (axis != (logits_shape_.size() - 1)) {
// the softmax axis(a) is not at the last dimension.
// logit shape: [N1, ..., C, ..., Nk]
// label shape: [N1, ..., 1, ..., Nk]
// _____^_____
// dim: 0, ..., a, ..., k-1
// needs to transpose the softmax axis in logit to last dimension
// with following transpose perm: [0, ..., a-1, a+1, ..., k-1, a]
std::vector<int64_t> trans(logits_shape_.size(), 0);
std::iota(trans.begin(), trans.begin() + axis, 0);
std::iota(trans.begin() + axis, trans.end() - 1, axis + 1);
trans.back() = axis;
// transpose logits
logits =
CreateBaseOp(
graph, node, "popart_transpose", {logits}, {}, {{"perm", trans}})
->outputs.front();
// no need to transpose label, transform the label size and reshape later.
std::transform(
trans.cbegin(),
trans.cend(),
label_transposed.begin(),
[&label_shape_](int64_t index) { return label_shape_[index]; });
}
if (label_transposed.back() == 1) {
// input shape: [N1, N2, ... , Nk, C]
// label shape: [N1, N2, ... , Nk, 1]
// reshape label shape to [N1, N2, ... , Nk]
std::vector<int64_t> new_shape_(label_transposed.begin(),
label_transposed.end() - 1);
auto const_before_loss =
CreateBaseOp(
graph,
node,
"popart_constant",
{},
{},
{{"value", new_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
{"dtype", ONNXDataType::INT64}})
->outputs.front();
cast_and_reshape = CreateBaseOp(graph,
node,
"popart_reshape",
{cast_and_reshape, const_before_loss},
{},
{})
->outputs.front();
}
auto log = CreateBaseOp(graph, node, "popart_log", {logits}, {}, {})
->outputs.front();
bool reshape_back = reduction == 2 && label_transposed.back() == 1;
final_loss_node = CreateBaseOp(graph,
node,
"popart_nllloss_v2",
{log, cast_and_reshape},
!(reshape_back || append_identity_loss)
? std::vector<Node *>{output}
: std::vector<Node *>{},
{
{"reduction", reduction},
{"ignoreIndex", ignore_index},
{"inputIsLogProbability", true},
})
->outputs.front();
if (reshape_back) {
// reshape output to the shape of input label.
auto const_after_loss =
CreateBaseOp(
graph,
node,
"popart_constant",
{},
{},
{{"value", label_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
{"dtype", ONNXDataType::INT64}})
->outputs.front();
final_loss_node =
CreateBaseOp(graph,
node,
"popart_reshape",
{final_loss_node, const_after_loss},
append_identity_loss ? std::vector<Node *>{}
: std::vector<Node *>{output},
{})
->outputs.front();
}
if (append_identity_loss) {
final_loss_node =
CreateIdentityLossOp(graph, node, {final_loss_node}, {output}, 2);
}
return final_loss_node;
}
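The transpose built above moves the softmax axis a of a rank-k logits tensor to the last position, i.e. perm = [0, ..., a-1, a+1, ..., k-1, a]. A standalone sketch of that construction using the same std::iota pattern (illustrative helper name, not part of the handler):

#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Build the perm that moves dimension `axis` of a rank-`rank` tensor to the end.
std::vector<int64_t> MoveAxisLast(int rank, int axis) {
  std::vector<int64_t> trans(rank, 0);
  std::iota(trans.begin(), trans.begin() + axis, 0);           // 0 .. axis-1
  std::iota(trans.begin() + axis, trans.end() - 1, axis + 1);  // axis+1 .. rank-1
  trans.back() = axis;                                         // axis goes last
  return trans;
}

int main() {
  // rank-4 logits with softmax axis 1: [N, C, H, W] -> [N, H, W, C]
  assert((MoveAxisLast(4, 1) == std::vector<int64_t>{0, 2, 3, 1}));
  return 0;
}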
Node *cross_entropy2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
int reduction = RemoveTailReduction(graph, node, "Y");
auto logits = GetInputVarNode("X", node);
auto label = GetInputVarNode("Label", node);
auto output = GetOutputVarNode("Y", node);
auto ignore_index = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
return cross_entropy_general_handler(graph,
node,
logits,
label,
output,
false, /*soft_label*/
ignore_index,
reduction,
-1); /*axis*/
}
Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) {
auto *op = node->Op();
int reduction = RemoveTailReduction(graph, node, "Loss");
auto logits = GetInputVarNode("Logits", node);
auto label = GetInputVarNode("Label", node);
auto output = GetOutputVarNode("Loss", node);
auto ignore_index = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto soft_label = BOOST_GET_CONST(bool, op->GetAttr("soft_label"));
logits = CreateSoftmaxOpset11(
graph, node, {logits}, {GetOutputVarNode("Softmax", node)}, axis)
->outputs.front();
return cross_entropy_general_handler(graph,
node,
logits,
label,
output,
soft_label,
ignore_index,
reduction,
axis);
}
Node *kldiv_loss_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto reduction = ConvertToPopartReduction(
BOOST_GET_CONST(std::string, op->GetAttr("reduction")));
if (reduction == 2) {
reduction = RemoveTailReduction(graph, node, "Loss");
}
bool append_identity_loss = is_dynamic_graph();
bool is_last_var_node = IsLastVarNode(GetOutputVarNode("Loss", node));
append_identity_loss = append_identity_loss && is_last_var_node;
// log(pred)
auto log =
CreateBaseOp(
graph, node, "popart_log", {GetInputVarNode("Target", node)}, {}, {})
->outputs.front();
// log(pred) - label
auto log_minus =
CreateBaseOp(
graph, node, "popart_sub", {log, GetInputVarNode("X", node)}, {}, {})
->outputs.front();
// label * (log(pred) - label)
auto loss =
CreateBaseOp(graph,
node,
"popart_mul",
{GetInputVarNode("Target", node), log_minus},
append_identity_loss || reduction != 2
? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Loss", node)},
{});
auto attrs = AttributeMap{{"reduce_all", true}, {"keepdims", 0L}};
if (append_identity_loss) {
loss = CreateIdentityLossOp(graph,
node,
loss->outputs,
{GetOutputVarNode("Loss", node)},
reduction);
} else if (reduction == 0) {
// Sum
loss = CreateBaseOp(graph,
node,
"popart_reducesum",
loss->outputs,
{GetOutputVarNode("Loss", node)},
attrs);
} else if (reduction == 1) {
// Mean
loss = CreateBaseOp(graph,
node,
"popart_reducemean",
loss->outputs,
{GetOutputVarNode("Loss", node)},
attrs);
}
return loss;
}
Node *binary_cross_entropy_handler(Graph *graph, Node *node) {
// Out = -1 * weight * (label * log(x) + (1 - label) * log(1 - x))
int reduction = 2;
if (is_dynamic_graph()) {
reduction = RemoveTailReduction(graph, node, "Out");
}
bool append_identity_loss =
is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Loss", node));
auto x = GetInputVarNode("X", node);
auto label = GetInputVarNode("Label", node);
// log(x)
auto log =
CreateBaseOp(graph, node, "popart_log", {x}, {}, {})->outputs.front();
// label * log(x)
auto log_mul = CreateBaseOp(graph, node, "popart_mul", {label, log}, {}, {})
->outputs.front();
// const one
auto one =
CreateConst(graph, node, std::vector<float>{1.0}, {1}, GetVarDType(x))
->outputs.front();
// (1 - x)
auto minus_input = CreateBaseOp(graph, node, "popart_sub", {one, x}, {}, {})
->outputs.front();
// log(1 - x)
auto log_minus_input =
CreateBaseOp(graph, node, "popart_log", {minus_input}, {}, {})
->outputs.front();
// (1 - label)
auto minus_label =
CreateBaseOp(graph, node, "popart_sub", {one, label}, {}, {})
->outputs.front();
// (1 - label) * log(1 - x)
auto minus_log_mul =
CreateBaseOp(
graph, node, "popart_mul", {minus_label, log_minus_input}, {}, {})
->outputs.front();
// (label * log(x) + (1 - label) * log(1 - x))
auto add =
CreateBaseOp(graph, node, "popart_add", {log_mul, minus_log_mul}, {}, {})
->outputs.front();
// -1 * (label * log(x) + (1 - label) * log(1 - x))
auto loss = CreateBaseOp(
graph,
node,
"popart_neg",
{add},
append_identity_loss ? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Out", node)},
{});
if (append_identity_loss) {
loss = CreateIdentityLossOp(
graph, node, loss->outputs, {GetOutputVarNode("Out", node)}, reduction);
}
return loss;
}
Node *huber_loss_handler(Graph *graph, Node *node) {
// if abs(label - input) < delta
// huber_loss = 0.5 * (label - input) * (label - input)
// else
// huber_loss = delta * abs(label - input) - 0.5 * delta * delta
auto *op = node->Op();
int reduction = 2;
if (is_dynamic_graph()) {
reduction = RemoveTailReduction(graph, node, "Out");
}
bool append_identity_loss =
is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Out", node));
auto x = GetInputVarNode("X", node);
auto label = GetInputVarNode("Y", node);
// (label - input)
auto diff = CreateBaseOp(graph, node, "popart_sub", {label, x}, {}, {})
->outputs.front();
// abs(label - input)
auto abs_diff =
CreateBaseOp(graph, node, "popart_abs", {diff}, {}, {})->outputs.front();
// const 0.5
auto dot_five =
CreateConst(graph, node, std::vector<float>{0.5}, {1}, GetVarDType(x))
->outputs.front();
// const delta
auto delta_value = BOOST_GET_CONST(float, op->GetAttr("delta"));
auto delta =
CreateConst(
graph, node, std::vector<float>{delta_value}, {1}, GetVarDType(x))
->outputs.front();
auto delta_square_coff =
CreateConst(graph,
node,
std::vector<float>{0.5f * delta_value * delta_value},
{1},
GetVarDType(x))
->outputs.front();
// (label - input) * (label - input)
auto square = CreateBaseOp(graph, node, "popart_mul", {diff, diff}, {}, {})
->outputs.front();
// 0.5 * (label - input) * (label - input)
auto dot_five_square =
CreateBaseOp(graph, node, "popart_mul", {dot_five, square}, {}, {})
->outputs.front();
// delta * abs(label - input)
auto delta_mul_diff =
CreateBaseOp(graph, node, "popart_mul", {delta, abs_diff}, {}, {})
->outputs.front();
// delta * abs(label - input) - 0.5 * delta * delta
auto sub_delta_square = CreateBaseOp(graph,
node,
"popart_sub",
{delta_mul_diff, delta_square_coff},
{},
{})
->outputs.front();
// abs(label - input) < delta
auto less_cond =
CreateBaseOp(graph, node, "popart_less", {abs_diff, delta}, {}, {})
->outputs.front();
auto loss = CreateBaseOp(
graph,
node,
"popart_where",
{less_cond, dot_five_square, sub_delta_square},
append_identity_loss ? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Out", node)},
{});
if (append_identity_loss) {
loss = CreateIdentityLossOp(
graph, node, loss->outputs, {GetOutputVarNode("Out", node)}, reduction);
}
return loss;
}
Node *warpctc_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto logits = GetInputVarNode("Logits", node);
auto label = GetInputVarNode("Label", node);
auto logits_length = GetInputVarNode("LogitsLength", node);
auto label_length = GetInputVarNode("LabelLength", node);
auto blank = BOOST_GET_CONST(int, op->GetAttr("blank"));
auto norm_by_times = BOOST_GET_CONST(bool, op->GetAttr("norm_by_times"));
int reduction = 2;
if (is_dynamic_graph()) {
reduction = RemoveTailReduction(graph, node, "Loss");
}
bool append_identity_loss =
is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Loss", node));
if (norm_by_times) {
PADDLE_THROW(platform::errors::InvalidArgument(
"norm_by_times is not supported yet in IPU"));
}
int axis = -1;
auto softmax_logits =
CreateSoftmaxOpset11(graph, node, {logits}, {}, axis)->outputs.front();
auto log_softmax_logits =
CreateBaseOp(graph, node, "popart_log", {softmax_logits}, {}, {})
->outputs.front();
auto cast_label = CreateBaseOp(graph,
node,
"popart_cast",
{label},
{},
{{"to", std::string("UINT32")}})
->outputs.front();
auto cast_logits_length = CreateBaseOp(graph,
node,
"popart_cast",
{logits_length},
{},
{{"to", std::string("UINT32")}})
->outputs.front();
auto cast_label_length = CreateBaseOp(graph,
node,
"popart_cast",
{label_length},
{},
{{"to", std::string("UINT32")}})
->outputs.front();
// TODO(czr): zero_infinity is not supported in current sdk which lead
// difference with paddle result.
auto loss = CreateBaseOp(
graph,
node,
"popart_ctcloss",
{log_softmax_logits, cast_label, cast_logits_length, cast_label_length},
append_identity_loss
? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Loss", node)},
{{"blank", blank},
{"reduction", reduction},
{"outDataType", std::string("UNDEFINED")}});
if (append_identity_loss) {
loss = CreateIdentityLossOp(
graph, node, loss->outputs, {GetOutputVarNode("Loss", node)}, 2);
}
return loss;
}
} // namespace
} // namespace ipu
} // namespace platform
} // namespace paddle
REGISTER_HANDLER(identity_loss, identity_loss_handler);
REGISTER_HANDLER(softmax_with_cross_entropy,
softmax_with_cross_entropy_handler);
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(kldiv_loss, kldiv_loss_handler);
REGISTER_HANDLER(bce_loss, binary_cross_entropy_handler);
REGISTER_HANDLER(huber_loss, huber_loss_handler);
REGISTER_HANDLER(warpctc, warpctc_handler);
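As a cross-check of the formula quoted in huber_loss_handler, a scalar reference (plain C++, not the PopArt lowering) looks like this: huber = 0.5 * (label - input)^2 when |label - input| < delta, otherwise delta * |label - input| - 0.5 * delta^2.

#include <cassert>
#include <cmath>

// Scalar reference for the formula documented in huber_loss_handler above.
float HuberLossRef(float input, float label, float delta) {
  const float diff = label - input;
  const float abs_diff = std::fabs(diff);
  if (abs_diff < delta) {
    return 0.5f * diff * diff;
  }
  return delta * abs_diff - 0.5f * delta * delta;
}

int main() {
  // |1.0 - 0.8| = 0.2 < delta = 1.0 -> quadratic branch: 0.5 * 0.04 = 0.02
  assert(std::fabs(HuberLossRef(0.8f, 1.0f, 1.0f) - 0.02f) < 1e-6f);
  // |5 - 1| = 4 >= delta = 1.0 -> linear branch: 1 * 4 - 0.5 = 3.5
  assert(std::fabs(HuberLossRef(1.0f, 5.0f, 1.0f) - 3.5f) < 1e-6f);
  return 0;
}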
...@@ -114,14 +114,29 @@ Node *matmul_handler(Graph *graph, Node *node) {
auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X"));
auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y"));
auto alpha = BOOST_GET_CONST(float, op->GetAttr("alpha"));
Node *x_node = GetInputVarNode("X", node);
Node *y_node = GetInputVarNode("Y", node);
int x_rank = x_node->Var()->GetShape().size();
int y_rank = y_node->Var()->GetShape().size();
auto gen_perm = [](const int rank) -> std::vector<int64_t> {
std::vector<int64_t> perm;
if (rank == 1) {
perm = std::vector<int64_t>{0};
} else if (rank == 2) {
perm = std::vector<int64_t>{1, 0};
} else if (rank == 3) {
perm = std::vector<int64_t>{0, 2, 1};
} else if (rank == 4) {
perm = std::vector<int64_t>{0, 1, 3, 2};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"op matmul with input rank == %d", rank));
}
return perm;
};
if (x_rank == 2) {
if (!transpose_x && !transpose_y && is_float_equal(alpha, 1.0f)) {
return CreateBaseOp(
graph,
...@@ -137,18 +152,10 @@ Node *matmul_handler(Graph *graph, Node *node) {
transpose_x,
transpose_y,
alpha);
} else if (x_rank == 3) {
perm = std::vector<int64_t>{0, 2, 1};
} else if (x_rank == 4) {
perm = std::vector<int64_t>{0, 1, 3, 2};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"op matmul with input rank == %d", x_rank));
}
Node *x_node = GetInputVarNode("X", node);
Node *y_node = GetInputVarNode("Y", node);
if (transpose_x) {
auto perm = gen_perm(x_rank);
x_node = CreateBaseOp(graph,
node,
"popart_transpose",
...@@ -158,6 +165,7 @@ Node *matmul_handler(Graph *graph, Node *node) {
x_node = x_node->outputs[0];
}
if (transpose_y) {
auto perm = gen_perm(y_rank);
y_node = CreateBaseOp(graph,
node,
"popart_transpose",
...@@ -209,7 +217,7 @@ Node *scale_handler(Graph *graph, Node *node) {
CreateCast(graph, node, {GetInputVarNode("X", node)}, {}, VarType::FP32);
Node *result = nullptr;
if (op->InputArgumentNames().size() > 1) {
auto scale = GetInputVarNode("ScaleTensor", node);
if (is_float_equal(bias_, 0.0)) {
result = CreateBaseOp(
...@@ -321,183 +329,6 @@ Node *scale_handler(Graph *graph, Node *node) {
return result_after_cast;
}
Node *cross_entropy2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
Node *new_cast = nullptr;
if (GetInputVarNode("Label", node)->Var()->GetDataType() == VarType::INT32) {
new_cast = GetInputVarNode("Label", node);
} else {
auto new_cast = CreateCast(
graph, node, {GetInputVarNode("Label", node)}, {}, VarType::INT32);
new_cast = new_cast->outputs[0];
}
auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
if (label_shape_[label_shape_.size() - 1] != 1) {
auto log = CreateBaseOp(
graph, node, "popart_log", {GetInputVarNode("X", node)}, {}, {});
return CreateBaseOp(
graph,
node,
"popart_nllloss_v2",
{log->outputs[0], new_cast},
{GetOutputVarNode("Y", node)},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
} else {
std::vector<int64_t> new_shape_{label_shape_[0]};
auto const_before_loss = CreateBaseOp(
graph,
node,
"popart_constant",
{},
{},
{{"value", new_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
auto reshape_before_loss =
CreateBaseOp(graph,
node,
"popart_reshape",
{new_cast, const_before_loss->outputs[0]},
{},
{});
auto log = CreateBaseOp(
graph, node, "popart_log", {GetInputVarNode("X", node)}, {}, {});
auto nllloss = CreateBaseOp(
graph,
node,
"popart_nllloss_v2",
{log->outputs[0], reshape_before_loss->outputs[0]},
{},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
auto const_after_loss = CreateBaseOp(
graph,
node,
"popart_constant",
{},
{},
{{"value", label_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
auto reshape_after_loss =
CreateBaseOp(graph,
node,
"popart_reshape",
{nllloss->outputs[0], const_after_loss->outputs[0]},
{GetOutputVarNode("Y", node)},
{});
return reshape_after_loss;
}
}
Node *softmax_with_cross_entropy_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index"));
auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto soft_label = BOOST_GET_CONST(bool, op->GetAttr("soft_label"));
if (soft_label) {
PADDLE_THROW(platform::errors::InvalidArgument(
"soft_label is not supported yet in IPU"));
}
Node *new_cast = nullptr;
if (GetInputVarNode("Label", node)->Var()->GetDataType() == VarType::INT32) {
new_cast = GetInputVarNode("Label", node);
} else {
auto new_cast = CreateCast(
graph, node, {GetInputVarNode("Label", node)}, {}, VarType::INT32);
new_cast = new_cast->outputs[0];
}
auto softmax_node = CreateSoftmaxOpset11(
graph, node, {GetInputVarNode("Logits", node)}, {}, axis);
auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape();
if (label_shape_[label_shape_.size() - 1] != 1) {
auto log = CreateBaseOp(
graph, node, "popart_log", {softmax_node->outputs[0]}, {}, {});
// softmax_with_cross_entropy is split to several ops in python.
// reduction is not needed here.
return CreateBaseOp(
graph,
node,
"popart_nllloss_v2",
{log->outputs[0], new_cast},
{GetOutputVarNode("Loss", node)},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
} else {
std::vector<int64_t> new_shape_{label_shape_[0]};
auto const_before_loss = CreateBaseOp(
graph,
node,
"popart_constant",
{},
{},
{{"value", new_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(new_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
auto reshape_before_loss =
CreateBaseOp(graph,
node,
"popart_reshape",
{new_cast, const_before_loss->outputs[0]},
{},
{});
auto log = CreateBaseOp(
graph, node, "popart_log", {softmax_node->outputs[0]}, {}, {});
auto nllloss = CreateBaseOp(
graph,
node,
"popart_nllloss_v2",
{log->outputs[0], reshape_before_loss->outputs[0]},
{},
{
{"reduction", 2}, // popart::ReductionType::NoReduction
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", true},
});
auto const_after_loss = CreateBaseOp(
graph,
node,
"popart_constant",
{},
{},
{{"value", label_shape_},
{"dims",
std::vector<int64_t>{static_cast<int64_t>(label_shape_.size())}},
{"dtype", ONNXDataType::INT64}});
auto reshape_after_loss =
CreateBaseOp(graph,
node,
"popart_reshape",
{nllloss->outputs[0], const_after_loss->outputs[0]},
{GetOutputVarNode("Loss", node)},
{});
return reshape_after_loss;
}
}
Node *cumsum_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive"));
...@@ -512,41 +343,63 @@ Node *cumsum_handler(Graph *graph, Node *node) {
{{"value", std::vector<int64_t>{axis}},
{"dims", std::vector<int64_t>{1}},
{"dtype", ONNXDataType::INT64}});
Node *input_x = nullptr;
auto data_type_ = GetInputVarNode("X", node)->Var()->GetDataType();
bool need_cast = data_type_ != VarType::FP32;
std::vector<Node *> cumsum_out;
if (need_cast) {
auto cast_x = CreateCast(
graph, node, {GetInputVarNode("X", node)}, {}, VarType::FP32);
input_x = cast_x->outputs[0];
} else {
input_x = GetInputVarNode("X", node);
cumsum_out.emplace_back(GetOutputVarNode("Out", node));
}
auto cumsum_node = CreateBaseOp(
graph,
node,
"popart_cumsum",
{input_x, axis_node->outputs[0]},
cumsum_out,
{{"exclusive", popart_exclusive}, {"reverse", popart_reverse}});
if (need_cast) {
cumsum_node = CreateCast(graph,
node,
cumsum_node->outputs,
{GetOutputVarNode("Out", node)},
data_type_);
}
return cumsum_node;
}
Node *matmul_v2_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("trans_x"));
auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("trans_y"));
Node *x_node = GetInputVarNode("X", node);
Node *y_node = GetInputVarNode("Y", node);
int x_rank = x_node->Var()->GetShape().size();
int y_rank = y_node->Var()->GetShape().size();
auto gen_perm = [](const int rank) -> std::vector<int64_t> {
std::vector<int64_t> perm;
if (rank == 1) {
perm = std::vector<int64_t>{0};
} else if (rank == 2) {
perm = std::vector<int64_t>{1, 0};
} else if (rank == 3) {
perm = std::vector<int64_t>{0, 2, 1};
} else if (rank == 4) {
perm = std::vector<int64_t>{0, 1, 3, 2};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"op matmul with input rank == %d", rank));
}
return perm;
};
if (transpose_x) {
auto perm = gen_perm(x_rank);
x_node = CreateBaseOp(graph,
node,
"popart_transpose",
...@@ -556,6 +409,7 @@ Node *matmul_v2_handler(Graph *graph, Node *node) {
x_node = x_node->outputs[0];
}
if (transpose_y) {
auto perm = gen_perm(y_rank);
y_node = CreateBaseOp(graph,
node,
"popart_transpose",
...@@ -611,9 +465,6 @@ REGISTER_HANDLER(matmul, matmul_handler);
REGISTER_HANDLER(sum, sum_handler);
REGISTER_HANDLER(softmax, softmax_handler);
REGISTER_HANDLER(scale, scale_handler);
REGISTER_HANDLER(softmax_with_cross_entropy,
softmax_with_cross_entropy_handler);
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(cumsum, cumsum_handler);
REGISTER_HANDLER(matmul_v2, matmul_v2_handler);
REGISTER_HANDLER(bmm, bmm_handler);
......
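The gen_perm lambda shared by matmul_handler and matmul_v2_handler swaps only the last two axes for ranks 2 to 4, leaves leading batch axes alone, and returns rank-1 inputs unchanged, which is what transpose_X/trans_x mean for a matmul. A standalone sketch of the same rank-to-perm mapping (assumed name GenPermSketch, outside the handlers):

#include <cassert>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Same rank -> perm table as gen_perm above: swap the last two axes,
// leave any leading batch axes untouched.
std::vector<int64_t> GenPermSketch(int rank) {
  if (rank == 1) return {0};
  if (rank == 2) return {1, 0};
  if (rank == 3) return {0, 2, 1};
  if (rank == 4) return {0, 1, 3, 2};
  throw std::invalid_argument("matmul transpose only handles rank 1-4");
}

int main() {
  assert((GenPermSketch(4) == std::vector<int64_t>{0, 1, 3, 2}));
  return 0;
}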
...@@ -55,9 +55,20 @@ Node *MakeOpNode(Graph *graph,
op_desc->SetType(type);
auto op = graph->CreateOpNode(op_desc.get());
// inputs
std::vector<std::string> input_names;
for (auto *in : inputs) {
if (in != nullptr) {
ConnectNodes(in, op);
input_names.push_back(in->Name());
} else {
input_names.push_back(std::string(""));
}
}
op->Op()->SetInput("__inputs__", input_names);
// outputs
std::vector<std::string> output_names;
if (outputs.empty()) {
auto var = MakeVarNode(graph, node);
ConnectNodes(op, var);
...@@ -66,14 +77,6 @@ Node *MakeOpNode(Graph *graph,
ConnectNodes(op, out);
}
}
// i/o
std::vector<std::string> input_names;
for (auto node : op->inputs) {
input_names.push_back(node->Name());
}
op->Op()->SetInput("__inputs__", input_names);
std::vector<std::string> output_names;
for (auto node : op->outputs) {
output_names.push_back(node->Name());
}
...@@ -138,6 +141,19 @@ Node *CreateCast(Graph *graph,
graph, node, "popart_cast", inputs, outputs, {{"to", to}});
}
Node *CreateIdentityLossOp(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
int reduction) {
return CreateBaseOp(graph,
node,
"popart_identity_loss",
inputs,
outputs,
{{"reduction", reduction}});
}
Node *CreateGemm(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
......
...@@ -67,6 +67,12 @@ Node *CreateCast(Graph *graph,
const std::vector<Node *> &outputs,
const VarType::Type otype);
Node *CreateIdentityLossOp(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
const std::vector<Node *> &outputs,
int reduction);
Node *CreateGemm(Graph *graph,
Node *node,
const std::vector<Node *> &inputs,
......
...@@ -85,17 +85,6 @@ Node *identity_handler(Graph *graph, Node *node) {
graph, node, "popart_identity", node->inputs, node->outputs);
}
Node *identity_loss_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction"));
return CreateBaseOp(graph,
node,
"popart_identity_loss",
node->inputs,
node->outputs,
{{"reduction", reduction}});
}
Node *detach_handler(Graph *graph, Node *node) {
return CreateBaseOp(
graph, node, "popart_detach_v2", node->inputs, node->outputs);
...@@ -112,5 +101,4 @@ REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler);
REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler);
REGISTER_HANDLER(custom_nll_loss, custom_nll_loss_handler);
REGISTER_HANDLER(identity, identity_handler);
REGISTER_HANDLER(identity_loss, identity_loss_handler);
REGISTER_HANDLER(detach, detach_handler);
...@@ -36,6 +36,27 @@ Node *reduce_op_handler(Graph *graph, Node *node, const std::string &op_name) {
return CreateBaseOp(graph, node, op_name, node->inputs, node->outputs, attrs);
}
Node *reduce_all_op_handler(Graph *graph,
Node *node,
const std::string &op_name) {
auto *op = node->Op();
auto attrs = AttributeMap{};
auto reduce_all = BOOST_GET_CONST(bool, op->GetAttr("reduce_all"));
if (!reduce_all) {
auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("dim"));
auto axes = std::vector<int64_t>{axes_.begin(), axes_.end()};
attrs.emplace("axes", axes);
}
auto keepdims_ = BOOST_GET_CONST(bool, op->GetAttr("keep_dim"));
auto keepdims = int64_t{keepdims_};
attrs.emplace("keepdims", keepdims);
auto int32_x =
CreateCast(graph, node, node->inputs, {}, VarType::INT32)->outputs[0];
auto reduce_op = CreateBaseOp(graph, node, op_name, {int32_x}, {}, attrs);
return CreateCast(
graph, node, reduce_op->outputs, node->outputs, VarType::BOOL);
}
Node *reduce_mean_handler(Graph *graph, Node *node) {
return reduce_op_handler(graph, node, "popart_reducemean");
}
...@@ -56,6 +77,34 @@ Node *reduce_prod_handler(Graph *graph, Node *node) {
return reduce_op_handler(graph, node, "popart_reduceprod");
}
Node *logsumexp_handler(Graph *graph, Node *node) {
auto *op = node->Op();
auto attrs = AttributeMap{};
auto reduce_all = BOOST_GET_CONST(bool, op->GetAttr("reduce_all"));
if (!reduce_all) {
auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axis"));
auto axes = std::vector<int64_t>{axes_.begin(), axes_.end()};
attrs.emplace("axes", axes);
}
auto keepdims_ = BOOST_GET_CONST(bool, op->GetAttr("keepdim"));
auto keepdims = int64_t{keepdims_};
attrs.emplace("keepdims", keepdims);
return CreateBaseOp(graph,
node,
"popart_reducelogsumexp",
node->inputs,
node->outputs,
attrs);
}
Node *reduce_all_handler(Graph *graph, Node *node) {
return reduce_all_op_handler(graph, node, "popart_reducemin");
}
Node *reduce_any_handler(Graph *graph, Node *node) {
return reduce_all_op_handler(graph, node, "popart_reducemax");
}
} // namespace
} // namespace ipu
} // namespace platform
...@@ -66,3 +115,6 @@ REGISTER_HANDLER(reduce_min, reduce_min_handler);
REGISTER_HANDLER(reduce_sum, reduce_sum_handler);
REGISTER_HANDLER(reduce_max, reduce_max_handler);
REGISTER_HANDLER(reduce_prod, reduce_prod_handler);
REGISTER_HANDLER(logsumexp, logsumexp_handler);
REGISTER_HANDLER(reduce_all, reduce_all_handler);
REGISTER_HANDLER(reduce_any, reduce_any_handler);
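reduce_all_op_handler casts the boolean input to INT32, applies reducemin (used for reduce_all) or reducemax (used for reduce_any), then casts back to BOOL; over {0, 1} values the min is 1 iff every element is true and the max is 1 iff any element is. A scalar sketch of that equivalence (plain C++, no graph API involved):

#include <algorithm>
#include <cassert>
#include <vector>

int main() {
  const std::vector<int> as_int = {1, 0, 1, 1};  // bools cast to int32, as in the handler
  const bool all_true = *std::min_element(as_int.begin(), as_int.end()) == 1;  // reducemin
  const bool any_true = *std::max_element(as_int.begin(), as_int.end()) == 1;  // reducemax
  assert(all_true == false);
  assert(any_true == true);
  return 0;
}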
...@@ -33,6 +33,7 @@ OP_DECL(popart_dynamicadd_v2, aiGraphcoreOpset.dynamicadd, ARG(INT_VEC,axes) ARG
OP_DECL(popart_sequenceslice_v2, aiGraphcoreOpset.sequenceslice, ARG(INT,zeroUnused) ) // NOLINT
OP_DECL(popart_replicatedallreduce_v2, aiGraphcoreOpset.replicatedallreduce, OPT_ARG(INT_VEC,commGroup) ) // NOLINT
OP_DECL(popart_ctcbeamsearchdecoder_v2, aiGraphcoreOpset.ctcbeamsearchdecoder, ARG(INT,blank) ARG(INT,beamWidth) ARG(INT,topPaths) ) // NOLINT
OP_DECL(popart_ctcloss, aiGraphcoreOpset.ctcloss, SIG_ARG(INT32,popart::ReductionType,reduction) ARG(INT32,blank) ARG(STRING,outDataType) ) // NOLINT
OP_DECL(popart_shapeddropout_v2, aiGraphcoreOpset.shapeddropout, ARG(INT_VEC,shape) ARG(FLOAT,ratio) ) // NOLINT
OP_DECL(popart_atan2_v2, aiGraphcoreOpset.atan2, NONE) // NOLINT
OP_DECL(popart_expm1_v2, aiGraphcoreOpset.expm1, NONE) // NOLINT
......