diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc index c9ac081f920dab27dc64d42f71342ae4c921c978..74f262be8477a1b515110563ad1b4c4823d74df0 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc @@ -57,14 +57,21 @@ Node *checkpointoutput_handler(Graph *graph, Node *node) { Node *custom_nll_loss_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction")); - auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignoreIndex")); + auto ignoreIndex = BOOST_GET_CONST(std::string, op->GetAttr("ignoreIndex")); auto inputIsLogProbability = BOOST_GET_CONST(bool, op->GetAttr("inputIsLogProbability")); - return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, - node->outputs, - {{"reduction", reduction}, - {"ignoreIndex", ignoreIndex}, - {"inputIsLogProbability", inputIsLogProbability}}); + if (ignoreIndex == "None") { + return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, + node->outputs, + {{"reduction", reduction}, + {"inputIsLogProbability", inputIsLogProbability}}); + } else { + return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, + node->outputs, + {{"reduction", reduction}, + {"ignoreIndex", std::atoi(ignoreIndex.c_str())}, + {"inputIsLogProbability", inputIsLogProbability}}); + } } Node *identity_handler(Graph *graph, Node *node) {