未验证 提交 8bf7cd85 编写于 作者: A Allen Guo 提交者: GitHub

[IPU] add more loss ops (#44646)

* add more loss ops

* add authors
Co-authored-by: NZhaorui Chen <zhaoruic@graphcore.ai>
Co-authored-by: NZhaorui Chen <zhaoruic@graphcore.ai>
上级 8a07d02c
......@@ -278,6 +278,123 @@ Node *kldiv_loss_handler(Graph *graph, Node *node) {
return loss;
}
Node *sigmoid_cross_entropy_with_logits_handler(Graph *graph, Node *node) {
// Out = max(logits, 0) - logits * label + log(1 + exp(-abs(logits)))
auto *op = node->Op();
int reduction = 2;
if (is_dynamic_graph()) {
reduction = RemoveTailReduction(graph, node, "Out");
}
bool append_identity_loss =
is_dynamic_graph() && IsLastVarNode(GetOutputVarNode("Out", node));
auto logits = GetInputVarNode("X", node);
auto label = GetInputVarNode("Label", node);
// sigmoid_cross_entropy_with_logits uses float label as input.
auto ignore_index_value =
static_cast<float>(PADDLE_GET_CONST(int, op->GetAttr("ignore_index")));
auto normalize = PADDLE_GET_CONST(bool, op->GetAttr("normalize"));
// const
auto one = CreateConst(
graph, node, std::vector<float>{1.0}, {1}, GetVarDType(logits))
->outputs.front();
auto zero =
CreateConst(
graph, node, std::vector<float>{0.0}, {1}, GetVarDType(logits))
->outputs.front();
auto ignore_index = CreateConst(graph,
node,
std::vector<float>{ignore_index_value},
{1},
GetVarDType(label))
->outputs.front();
// max(logits, 0)
auto max_zero =
CreateBaseOp(graph, node, "popart_max", {logits, zero}, {}, {})
->outputs.front();
// logits * label
auto mul = CreateBaseOp(graph, node, "popart_mul", {logits, label}, {}, {})
->outputs.front();
// abs(logits)
auto abs = CreateBaseOp(graph, node, "popart_abs", {logits}, {}, {})
->outputs.front();
// -abs(logits)
auto neg_abs =
CreateBaseOp(graph, node, "popart_neg", {abs}, {}, {})->outputs.front();
// exp(-abs(logits))
auto exp_neg_abs = CreateBaseOp(graph, node, "popart_exp", {neg_abs}, {}, {})
->outputs.front();
// 1+exp(-abs(logits))
auto log_term =
CreateBaseOp(graph, node, "popart_add", {exp_neg_abs, one}, {}, {})
->outputs.front();
// log(1+exp(-abs(logits)))
auto log = CreateBaseOp(graph, node, "popart_log", {log_term}, {}, {})
->outputs.front();
// max(logits, 0) - logits * label
auto sub = CreateBaseOp(graph, node, "popart_sub", {max_zero, mul}, {}, {})
->outputs.front();
// max(logits, 0) - logits * label + log(1 + exp(-abs(logits)))
auto loss = CreateBaseOp(graph, node, "popart_add", {sub, log}, {}, {})
->outputs.front();
// label == ignore_index ? 0 : loss
auto equal_cond =
CreateBaseOp(graph, node, "popart_equal", {label, ignore_index}, {}, {})
->outputs.front();
loss = CreateBaseOp(graph,
node,
"popart_where",
{equal_cond, zero, loss},
append_identity_loss || normalize
? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Out", node)},
{});
if (normalize) {
// normalize the output as: loss = loss / sum(label != ignore_index)
auto not_equal =
CreateBaseOp(graph, node, "popart_logical_not", {equal_cond}, {}, {})
->outputs.front();
auto mask =
CreateCast(graph, node, {not_equal}, {}, logits->Var()->GetDataType())
->outputs.front();
auto sum = CreateBaseOp(graph,
node,
"popart_reducesum",
{mask},
{},
{{"keepdims", int64_t{0}}})
->outputs.front();
auto eps =
CreateConst(
graph, node, std::vector<float>{1e-5}, {1}, GetVarDType(logits))
->outputs.front();
// avoid division by zero
auto add_eps = CreateBaseOp(graph, node, "popart_add", {sum, eps}, {}, {})
->outputs.front();
loss =
CreateBaseOp(graph,
node,
"popart_div",
{loss->outputs[0], add_eps},
append_identity_loss
? std::vector<Node *>{}
: std::vector<Node *>{GetOutputVarNode("Out", node)},
{});
}
if (append_identity_loss) {
loss = CreateIdentityLossOp(
graph, node, loss->outputs, {GetOutputVarNode("Out", node)}, reduction);
}
return loss;
}
Node *binary_cross_entropy_handler(Graph *graph, Node *node) {
// Out = -1 * weight * (label * log(x) + (1 - label) * log(1 - x))
int reduction = 2;
......@@ -493,6 +610,97 @@ Node *warpctc_handler(Graph *graph, Node *node) {
return loss;
}
Node *rank_loss_handler(Graph *graph, Node *node) {
// (1.0f + (left - right).exp()).log() - label * (left - right)
auto label = GetInputVarNode("Label", node);
auto left = GetInputVarNode("Left", node);
auto right = GetInputVarNode("Right", node);
auto output = GetOutputVarNode("Out", node);
int reduction = 2;
if (is_dynamic_graph()) {
reduction = RemoveTailReduction(graph, node, "Out");
}
bool append_identity_loss = is_dynamic_graph() && IsLastVarNode(output);
auto sub = CreateBaseOp(graph, node, "popart_sub", {left, right}, {}, {})
->outputs.front();
auto mul = CreateBaseOp(graph, node, "popart_mul", {label, sub}, {}, {})
->outputs.front();
// const
auto one =
CreateConst(graph, node, std::vector<float>{1.0}, {1}, GetVarDType(label))
->outputs.front();
auto exp =
CreateBaseOp(graph, node, "popart_exp", {sub}, {}, {})->outputs.front();
auto add = CreateBaseOp(graph, node, "popart_add", {one, exp}, {}, {})
->outputs.front();
auto log =
CreateBaseOp(graph, node, "popart_log", {add}, {}, {})->outputs.front();
auto loss = CreateBaseOp(graph,
node,
"popart_sub",
{log, mul},
append_identity_loss ? std::vector<Node *>{}
: std::vector<Node *>{output},
{})
->outputs.front();
if (append_identity_loss) {
loss =
CreateIdentityLossOp(graph, node, loss->outputs, {output}, reduction);
}
return loss;
}
Node *margin_rank_loss_handler(Graph *graph, Node *node) {
// rank_loss = max(0, -label * (left - right) + margin)
auto *op = node->Op();
auto label = GetInputVarNode("Label", node);
auto left = GetInputVarNode("X1", node);
auto right = GetInputVarNode("X2", node);
auto output = GetOutputVarNode("Out", node);
auto margin_value = PADDLE_GET_CONST(float, op->GetAttr("margin"));
int reduction = 2;
if (is_dynamic_graph()) {
reduction = RemoveTailReduction(graph, node, "Out");
}
bool append_identity_loss = is_dynamic_graph() && IsLastVarNode(output);
// -(left - right)
auto sub = CreateBaseOp(graph, node, "popart_sub", {right, left}, {}, {})
->outputs.front();
// -label * (left - right)
auto mul = CreateBaseOp(graph, node, "popart_mul", {label, sub}, {}, {})
->outputs.front();
// const
auto zero =
CreateConst(graph, node, std::vector<float>{0.0}, {1}, GetVarDType(label))
->outputs.front();
auto margin = CreateConst(graph,
node,
std::vector<float>{margin_value},
{1},
GetVarDType(label))
->outputs.front();
auto margin_add =
CreateBaseOp(graph, node, "popart_add", {mul, margin}, {}, {})
->outputs.front();
// max(0, term)
auto loss = CreateBaseOp(graph,
node,
"popart_max",
{zero, margin_add},
append_identity_loss ? std::vector<Node *>{}
: std::vector<Node *>{output},
{})
->outputs.front();
if (append_identity_loss) {
loss =
CreateIdentityLossOp(graph, node, loss->outputs, {output}, reduction);
}
return loss;
}
} // namespace
} // namespace ipu
} // namespace platform
......@@ -502,7 +710,11 @@ REGISTER_HANDLER(identity_loss, identity_loss_handler);
REGISTER_HANDLER(softmax_with_cross_entropy,
softmax_with_cross_entropy_handler);
REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler);
REGISTER_HANDLER(sigmoid_cross_entropy_with_logits,
sigmoid_cross_entropy_with_logits_handler);
REGISTER_HANDLER(kldiv_loss, kldiv_loss_handler);
REGISTER_HANDLER(bce_loss, binary_cross_entropy_handler);
REGISTER_HANDLER(huber_loss, huber_loss_handler);
REGISTER_HANDLER(warpctc, warpctc_handler);
REGISTER_HANDLER(rank_loss, rank_loss_handler);
REGISTER_HANDLER(margin_rank_loss, margin_rank_loss_handler);
......@@ -70,8 +70,8 @@ class TestBase(IPUD2STest):
self.loss_op = paddle.fluid.layers.cross_entropy
def set_data_feed(self):
self.data = paddle.uniform((32, 3, 10, 10), dtype='float32')
self.label = paddle.randint(0, 10, shape=[32], dtype='int64')
self.data = paddle.uniform((8, 3, 10, 10), dtype='float32')
self.label = paddle.randint(0, 10, shape=[8], dtype='int64')
def create_model(self, use_ipu=False):
return SimpleLayer(loss_op=self.loss_op,
......@@ -215,8 +215,8 @@ class TestWithoutIdentityLoss2(TestBase):
self.loss_op = paddle.fluid.layers.softmax_with_cross_entropy
def set_data_feed(self):
self.data = paddle.uniform((32, 3, 10, 10), dtype='float32')
self.label = paddle.randint(0, 10, shape=[32, 1], dtype='int64')
self.data = paddle.uniform((8, 3, 10, 10), dtype='float32')
self.label = paddle.randint(0, 10, shape=[8, 1], dtype='int64')
def create_model(self, use_ipu=False):
return SimpleLayer(loss_op=self.loss_op,
......@@ -231,8 +231,41 @@ class TestWithoutIdentityLoss3(TestBase):
self.loss_op = partial(paddle.fluid.layers.kldiv_loss, reduction="none")
def set_data_feed(self):
self.data = paddle.uniform((32, 3, 10, 10), dtype='float32')
self.label = paddle.rand(shape=[32, 81], dtype='float32')
self.data = paddle.uniform((8, 3, 10, 10), dtype='float32')
self.label = paddle.rand(shape=[8, 81], dtype='float32')
def create_model(self, use_ipu=False):
return SimpleLayer(loss_op=self.loss_op,
use_softmax=True,
use_reduction=True,
use_identity_loss=False)
class TestWithoutIdentityLoss4(TestBase):
def set_op_attrs(self):
self.loss_op = paddle.nn.functional.binary_cross_entropy
def set_data_feed(self):
self.data = paddle.uniform((8, 3, 10, 10), dtype='float32')
self.label = paddle.rand(shape=[8, 81], dtype='float32')
def create_model(self, use_ipu=False):
return SimpleLayer(loss_op=self.loss_op,
use_softmax=True,
use_reduction=False,
use_identity_loss=False)
class TestWithoutIdentityLoss5(TestBase):
def set_op_attrs(self):
self.loss_op = paddle.fluid.layers.sigmoid_cross_entropy_with_logits
def set_data_feed(self):
self.data = paddle.uniform((8, 3, 10, 10), dtype='float32')
self.label = paddle.randint(0, 10, shape=[8, 81],
dtype='int64').astype('float32')
def create_model(self, use_ipu=False):
return SimpleLayer(loss_op=self.loss_op,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
def set_data_feed(self):
label = np.random.uniform(size=[3, 1])
left = np.random.uniform(size=[3, 1])
right = np.random.uniform(size=[3, 1])
self.feed_fp32 = {
"label": label.astype(np.float32),
"left": left.astype(np.float32),
"right": right.astype(np.float32),
}
self.feed_fp16 = {
"label": label.astype(np.float16),
"left": left.astype(np.float16),
"right": right.astype(np.float16),
}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_op_attrs(self):
self.attrs = {
'margin': 0.1,
}
@IPUOpTest.static_graph
def build_model(self, on_ipu):
label = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype="float32")
left = paddle.static.data(name=self.feed_list[1],
shape=self.feed_shape[1],
dtype='float32')
right = paddle.static.data(name=self.feed_list[2],
shape=self.feed_shape[2],
dtype='float32')
out = paddle.fluid.layers.margin_rank_loss(label, left, right)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model(self.is_ipu_mode(m))
self.run_model(m)
self.check()
class TestCase1(TestBase):
def set_op_attrs(self):
self.attrs = {
'margin': 0.5,
}
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
def set_data_feed(self):
label = np.random.uniform(size=[3, 1])
left = np.random.uniform(size=[3, 1])
right = np.random.uniform(size=[3, 1])
self.feed_fp32 = {
"label": label.astype(np.float32),
"left": left.astype(np.float32),
"right": right.astype(np.float32),
}
self.feed_fp16 = {
"label": label.astype(np.float16),
"left": left.astype(np.float16),
"right": right.astype(np.float16),
}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
@IPUOpTest.static_graph
def build_model(self, on_ipu):
label = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype="float32")
left = paddle.static.data(name=self.feed_list[1],
shape=self.feed_shape[1],
dtype='float32')
right = paddle.static.data(name=self.feed_list[2],
shape=self.feed_shape[2],
dtype='float32')
out = paddle.fluid.layers.rank_loss(label, left, right)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model(self.is_ipu_mode(m))
self.run_model(m)
self.check()
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.static
from paddle.fluid.tests.unittests.ipu.op_test_ipu import IPUOpTest
import paddle.nn.functional as F
class TestBase(IPUOpTest):
def setUp(self):
self.set_atol()
self.set_training()
self.set_data_feed()
self.set_feed_attr()
self.set_op_attrs()
def set_data_feed(self):
x = np.random.uniform(size=[10])
label = np.arange(10).reshape([10])
self.feed_fp32 = {
"x": x.astype(np.float32),
"label": label.astype(np.float32)
}
self.feed_fp16 = {
"x": x.astype(np.float16),
"label": label.astype(np.float16)
}
def set_feed_attr(self):
self.feed_shape = [x.shape for x in self.feed_fp32.values()]
self.feed_list = list(self.feed_fp32.keys())
def set_op_attrs(self):
self.attrs = {
'ignore_index': -100,
}
@IPUOpTest.static_graph
def build_model(self, on_ipu):
x = paddle.static.data(name=self.feed_list[0],
shape=self.feed_shape[0],
dtype="float32")
label = paddle.static.data(name=self.feed_list[1],
shape=self.feed_shape[1],
dtype='float32')
out = paddle.fluid.layers.sigmoid_cross_entropy_with_logits(
x, label, **self.attrs)
self.fetch_list = [out.name]
def run_model(self, exec_mode):
self.run_op_test(exec_mode)
def test(self):
for m in IPUOpTest.ExecutionMode:
if not self.skip_mode(m):
self.build_model(self.is_ipu_mode(m))
self.run_model(m)
self.check()
class TestCase1(TestBase):
def set_op_attrs(self):
self.attrs = {
'ignore_index': 1,
}
class TestCase2(TestBase):
def set_atol(self):
# epsilon is added when normalize is True, use larger atol.
self.atol = 1e-6
self.rtol = 1e-5
self.atol_fp16 = 1e-3
self.rtol_fp16 = 1e-3
def set_op_attrs(self):
self.attrs = {
'ignore_index': 1,
'normalize': True,
}
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册