Commit 79cec531 authored by: jerrywgz

add ignore index for sigmoid cross entropy with logits op, test=develop

Parent f17b05d4
......@@ -100,6 +100,11 @@ class SigmoidCrossEntropyWithLogitsOpMaker
AddOutput("Out",
"(Tensor, default Tensor<float>), a 2-D tensor with shape N x D "
" of elementwise logistic losses.");
AddAttr<int>(
"ignore_index",
"(int, default -1), Specifies a target value that is ignored and"
"does not contribute to the input gradient.")
.SetDefault(-1);
AddComment(R"DOC(
SigmoidCrossEntropyWithLogits Operator.
......
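The semantics of the new attribute: any element whose label equals `ignore_index` contributes zero loss (and, via the backward functor below, zero gradient), which lets callers mask out padded or unlabeled positions. A minimal NumPy sketch of this behavior (the array values are illustrative, not from the patch):

```python
import numpy as np

logits = np.array([0.5, -0.3, 2.0, 1.2], dtype=np.float32)
labels = np.array([0., 1., -1., 1.], dtype=np.float32)  # -1 = ignored

# Numerically stable elementwise loss, as computed by the op.
loss = np.maximum(logits, 0) - logits * labels \
    + np.log1p(np.exp(-np.abs(logits)))
loss[labels == -1] = 0.0  # positions matching ignore_index drop out
```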
......@@ -15,33 +15,82 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
#include "paddle/legacy/utils/Logging.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename T>
struct SigmoidCrossEntropyWithLogitsForward {
// EIGEN_EMPTY_STRUCT_CTOR(SigmoidCrossEntropyWithLogitsForward)
HOSTDEVICE SigmoidCrossEntropyWithLogitsForward(const int &ignore_index)
: ignore_index(ignore_index) {}
HOSTDEVICE T operator()(const T &x, const T &label) const {
if (static_cast<int>(label) == ignore_index) {
return static_cast<T>(0.);
}
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-(std::abs(x))));
return term1 - term2 + term3;
}
int ignore_index;
};
template <typename T>
struct SigmoidCrossEntropyWithLogitsBackward {
// EIGEN_EMPTY_STRUCT_CTOR(SigmoidCrossEntropyWithLogitsBackward)
HOSTDEVICE SigmoidCrossEntropyWithLogitsBackward(const int &ignore_index)
: ignore_index(ignore_index) {}
HOSTDEVICE T operator()(const T &x, const T &label) const {
if (static_cast<int>(label) == ignore_index) {
return static_cast<T>(0.);
}
T sigmoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
return sigmoid_x - label;
}
int ignore_index;
};
// Out = max(X, 0) - X * Labels + log(1 + exp(-abs(X)))
template <typename DeviceContext, typename T>
class SigmoidCrossEntropyWithLogitsKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
Tensor *Out = context.Output<Tensor>("Out");
Out->mutable_data<T>(context.GetPlace());
int ignore_index = context.Attr<int>("ignore_index");
auto x = EigenVector<T>::Flatten(*X);
auto labels = EigenVector<T>::Flatten(*Labels);
auto out = EigenVector<T>::Flatten(*Out);
auto &place = *context.device_context<DeviceContext>().eigen_device();
out.device(place) = x.binaryExpr(
labels, SigmoidCrossEntropyWithLogitsForward<T>(ignore_index));
}
};
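The forward functor implements the usual numerically stable rewrite of the logistic loss: `max(x, 0) - x * label + log(1 + exp(-|x|))` equals `-label * log(sigmoid(x)) - (1 - label) * log(1 - sigmoid(x))` algebraically, but never exponentiates a large positive number. A quick NumPy check of the identity (a sketch; the sample values are arbitrary, not from the patch):

```python
import numpy as np
from scipy.special import expit  # sigmoid, as used by the op's unit test

x, label = 30.0, 1.0
stable = max(x, 0) - x * label + np.log1p(np.exp(-abs(x)))
naive = -label * np.log(expit(x)) - (1 - label) * np.log(1 - expit(x))
print(stable, naive)  # agree at ~9.36e-14; by x = 40 the naive
                      # form already collapses to 0.0 in float64
```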
......@@ -50,23 +99,23 @@ template <typename DeviceContext, typename T>
class SigmoidCrossEntropyWithLogitsGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
const Tensor *X = context.Input<Tensor>("X");
const Tensor *Labels = context.Input<Tensor>("Label");
const Tensor *dOut = context.Input<Tensor>(framework::GradVarName("Out"));
Tensor *dX = context.Output<Tensor>(framework::GradVarName("X"));
dX->mutable_data<T>(context.GetPlace());
auto ignore_index = context.Attr<int>("ignore_index");
auto x = EigenVector<T>::Flatten(*X);
auto labels = EigenVector<T>::Flatten(*Labels);
auto dout = EigenVector<T>::Flatten(*dOut);
auto dx = EigenVector<T>::Flatten(*dX);
auto &place =
*context.template device_context<DeviceContext>().eigen_device();
auto diff = x.binaryExpr(labels, SigmoidCrossEntropyWithLogitsBackward<T>(
static_cast<int>(ignore_index)));
dx.device(place) = dout * diff;
}
};
......
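The backward functor relies on the derivative `d/dx loss(x, label) = sigmoid(x) - label`, with ignored positions pinned to zero so no gradient flows through them. A finite-difference sanity check of that derivative (a sketch, not part of the patch):

```python
import numpy as np
from scipy.special import expit

def loss(x, label):
    # Same stable form as the forward functor above.
    return max(x, 0) - x * label + np.log1p(np.exp(-abs(x)))

x, label, eps = 0.7, 1.0, 1e-6
numeric = (loss(x + eps, label) - loss(x - eps, label)) / (2 * eps)
analytic = expit(x) - label
print(numeric, analytic)  # both ~ -0.3318
```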
......@@ -7892,13 +7892,14 @@ def mul(x, y, x_num_col_dims=1, y_num_col_dims=1, name=None):
@templatedoc()
def sigmoid_cross_entropy_with_logits(x, label, ignore_index=-1, name=None):
    """
    ${comment}

    Args:
        x(${x_type}): ${x_comment}
        label(${label_type}): ${label_comment}
        ignore_index(${ignore_index}): ${ignore_index_comment}
        name(basestring|None): Name of the output.

    Returns:
......@@ -7917,7 +7918,7 @@ def sigmoid_cross_entropy_with_logits(x, label, name=None):
        type="sigmoid_cross_entropy_with_logits",
        inputs={"X": x,
                "Label": label},
        attrs={"ignore_index": ignore_index},
        outputs={"Out": out})
    return out
......
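With the wrapper updated, the mask flows through from Python. A minimal usage sketch against the fluid layers API as changed here (the data names 'logits' and 'label' are illustrative):

```python
import paddle.fluid as fluid

x = fluid.layers.data(name='logits', shape=[10], dtype='float32')
lbl = fluid.layers.data(name='label', shape=[10], dtype='float32')
# Elements whose label equals ignore_index (-1) yield zero loss/gradient.
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
    x=x, label=lbl, ignore_index=-1)
```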
......@@ -170,9 +170,10 @@ class TestBook(unittest.TestCase):
        with program_guard(program):
            dat = layers.data(name='data', shape=[10], dtype='float32')
            lbl = layers.data(name='label', shape=[10], dtype='float32')
            ignore_index = -1
            self.assertIsNotNone(
                layers.sigmoid_cross_entropy_with_logits(
                    x=dat, label=lbl, ignore_index=ignore_index))
            print(str(program))

    def test_hsigmoid(self):
......
......@@ -56,6 +56,40 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
batch_size = 64
num_classes = 20
ignore_index = -1
self.inputs = {
'X': logit(
np.random.uniform(0, 1, (batch_size, num_classes))
.astype("float32")),
'Label': np.random.randint(-1, 2, (batch_size, num_classes))
.astype("float32")
}
self.attrs = {'ignore_index': ignore_index, }
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
# Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
sigmoid_X = expit(self.inputs['X'])
term1 = self.inputs['Label'] * np.log(sigmoid_X)
term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
out = -term1 - term2
out[np.where(self.inputs['Label'] == ignore_index)] = 0
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label
"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
batch_size = 64
......@@ -85,3 +119,4 @@ class TestSigmoidCrossEntropyWithLogitsOp2(OpTest):
np.random.seed(0)

if __name__ == '__main__':
    unittest.main()