From 6c935e1dc3fa4e97525ac97a605e1e79ae5e645e Mon Sep 17 00:00:00 2001 From: Allen Guo Date: Tue, 10 May 2022 10:35:56 +0800 Subject: [PATCH] set custom_nll_loss_op attr ignoreIndex to str (#42596) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit set attr ignoreIndex type to string for custom_nllloss_op 部分 cheery-pick of #42534 --- .../ipu/popart_canonicalization/other_ops.cc | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc index c9ac081f92..74f262be84 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc @@ -57,14 +57,21 @@ Node *checkpointoutput_handler(Graph *graph, Node *node) { Node *custom_nll_loss_handler(Graph *graph, Node *node) { auto *op = node->Op(); auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction")); - auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignoreIndex")); + auto ignoreIndex = BOOST_GET_CONST(std::string, op->GetAttr("ignoreIndex")); auto inputIsLogProbability = BOOST_GET_CONST(bool, op->GetAttr("inputIsLogProbability")); - return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, - node->outputs, - {{"reduction", reduction}, - {"ignoreIndex", ignoreIndex}, - {"inputIsLogProbability", inputIsLogProbability}}); + if (ignoreIndex == "None") { + return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, + node->outputs, + {{"reduction", reduction}, + {"inputIsLogProbability", inputIsLogProbability}}); + } else { + return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, + node->outputs, + {{"reduction", reduction}, + {"ignoreIndex", std::atoi(ignoreIndex.c_str())}, + {"inputIsLogProbability", inputIsLogProbability}}); + } } Node *identity_handler(Graph *graph, Node *node) { -- GitLab