未验证 提交 6c935e1d 编写于 作者: A Allen Guo 提交者: GitHub

set custom_nll_loss_op attr ignoreIndex to str (#42596)

set attr ignoreIndex type to string for custom_nllloss_op

部分 cheery-pick of #42534
上级 37715dab
...@@ -57,14 +57,21 @@ Node *checkpointoutput_handler(Graph *graph, Node *node) { ...@@ -57,14 +57,21 @@ Node *checkpointoutput_handler(Graph *graph, Node *node) {
Node *custom_nll_loss_handler(Graph *graph, Node *node) { Node *custom_nll_loss_handler(Graph *graph, Node *node) {
auto *op = node->Op(); auto *op = node->Op();
auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction")); auto reduction = BOOST_GET_CONST(int, op->GetAttr("reduction"));
auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignoreIndex")); auto ignoreIndex = BOOST_GET_CONST(std::string, op->GetAttr("ignoreIndex"));
auto inputIsLogProbability = auto inputIsLogProbability =
BOOST_GET_CONST(bool, op->GetAttr("inputIsLogProbability")); BOOST_GET_CONST(bool, op->GetAttr("inputIsLogProbability"));
if (ignoreIndex == "None") {
return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs, return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs,
node->outputs, node->outputs,
{{"reduction", reduction}, {{"reduction", reduction},
{"ignoreIndex", ignoreIndex},
{"inputIsLogProbability", inputIsLogProbability}}); {"inputIsLogProbability", inputIsLogProbability}});
} else {
return CreateBaseOp(graph, node, "popart_nllloss_v2", node->inputs,
node->outputs,
{{"reduction", reduction},
{"ignoreIndex", std::atoi(ignoreIndex.c_str())},
{"inputIsLogProbability", inputIsLogProbability}});
}
} }
Node *identity_handler(Graph *graph, Node *node) { Node *identity_handler(Graph *graph, Node *node) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册