未验证 提交 bb42d870 编写于 作者: S Siming Dai 提交者: GitHub

[BugFix] Fix bug for binary_cross_entropy_with_logits loss (#54869)

* add pos_weight in kernel

* fix unittest

* fix xpu

* fix bce unittest, change infermeta order
上级 57da105c
......@@ -1830,8 +1830,8 @@
data_type : x
- backward_op : sigmoid_cross_entropy_with_logits_grad
forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, bool normalize=false, int ignore_index=-100) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor out_grad, bool normalize, int ignore_index)
forward : sigmoid_cross_entropy_with_logits (Tensor x, Tensor label, Tensor pos_weight, bool normalize=false, int ignore_index=-100) -> Tensor(out)
args : (Tensor x, Tensor label, Tensor pos_weight, Tensor out_grad, bool normalize, int ignore_index)
output : Tensor(x_grad)
infer_meta :
func : UnchangedInferMeta
......@@ -1839,6 +1839,7 @@
kernel :
func : sigmoid_cross_entropy_with_logits_grad
inplace : (out_grad -> x_grad)
optional : pos_weight
- backward_op : sigmoid_double_grad
forward : sigmoid_grad (Tensor out, Tensor fwd_grad_out) -> Tensor(grad_x)
......
......@@ -2106,7 +2106,7 @@
backward : sigmoid_grad
- op : sigmoid_cross_entropy_with_logits
args : (Tensor x, Tensor label, bool normalize=false, int ignore_index=-100)
args : (Tensor x, Tensor label, Tensor pos_weight, bool normalize=false, int ignore_index=-100)
output : Tensor
infer_meta :
func : SigmoidCrossEntropyWithLogitsInferMeta
......@@ -2114,6 +2114,7 @@
func : sigmoid_cross_entropy_with_logits
inplace : (x -> out)
backward : sigmoid_cross_entropy_with_logits_grad
optional : pos_weight
- op : sign
args : (Tensor x)
......@@ -2514,7 +2515,7 @@
func : WeightedSampleNeighborsInferMeta
kernel :
func : weighted_sample_neighbors
optional: eids
optional : eids
- op : where
args : (Tensor condition, Tensor x, Tensor y)
......
......@@ -2672,47 +2672,6 @@ void SegmentPoolInferMeta(const MetaTensor& x,
}
}
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config) {
auto x_dims = x.dims();
auto labels_dims = label.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank,
labels_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
labels_dims.size()));
bool check = true;
if ((!config.is_runtime) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
phi::slice_ddim(x_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension. But received: the shape of "
"Input(X) is [%s], the shape of Input(Label) is [%s].",
x_dims,
labels_dims));
}
out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
void TakeAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
......
......@@ -417,13 +417,6 @@ void SegmentPoolInferMeta(const MetaTensor& x,
MetaTensor* summed_ids,
MetaConfig config = MetaConfig());
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config = MetaConfig());
void TakeAlongAxisInferMeta(const MetaTensor& x,
const MetaTensor& index,
int axis,
......
......@@ -2850,6 +2850,61 @@ void SgdInferMeta(const MetaTensor& param,
}
}
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& pos_weight,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config) {
auto x_dims = x.dims();
auto labels_dims = label.dims();
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank,
labels_dims.size(),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same rank."
"But received: the rank of Input(X) is [%d], "
"the rank of Input(Label) is [%d].",
rank,
labels_dims.size()));
bool check = true;
if ((!config.is_runtime) &&
(phi::product(x_dims) <= 0 || phi::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
phi::slice_ddim(x_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension. But received: the shape of "
"Input(X) is [%s], the shape of Input(Label) is [%s].",
x_dims,
labels_dims));
if (pos_weight) {
auto weight_dims = pos_weight.dims();
PADDLE_ENFORCE_EQ(
phi::slice_ddim(weight_dims, 0, rank),
phi::slice_ddim(labels_dims, 0, rank),
phi::errors::InvalidArgument(
"Input(pos_weight) and Input(Label) shall have the same shape "
"But received: the shape of Input(PosWeight) is [%s], "
"the shape of Input(Label) is [%s].",
weight_dims,
labels_dims));
}
}
out->set_dims(x_dims);
out->set_dtype(x.dtype());
out->share_lod(x);
}
void SendUERecvInferMeta(const MetaTensor& x,
const MetaTensor& y,
const MetaTensor& src_index,
......@@ -3489,5 +3544,6 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count->set_dims({-1});
out_count->set_dtype(DataType::INT32);
}
} // namespace phi
PD_REGISTER_INFER_META_FN(batch_norm_infer, phi::BatchNormInferInferMeta);
......@@ -542,6 +542,14 @@ void SgdInferMeta(const MetaTensor& param,
MetaTensor* param_out,
MetaTensor* master_param_out);
void SigmoidCrossEntropyWithLogitsInferMeta(const MetaTensor& x,
const MetaTensor& label,
const MetaTensor& pos_weight,
bool normalize,
int ignore_index,
MetaTensor* out,
MetaConfig config = MetaConfig());
void StackInferMeta(const std::vector<const MetaTensor*>& x,
int axis,
MetaTensor* out,
......
......@@ -20,9 +20,11 @@
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
void SigmoidCrossEntropyWithLogitsGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
......@@ -33,15 +35,20 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
auto x_data = x.data<T>();
auto label_data = label.data<T>();
auto dout_data = out_grad.data<T>();
auto pos_weight_data =
(pos_weight.get_ptr() == nullptr ? nullptr
: pos_weight.get_ptr()->data<T>());
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
T dout = dout_data[idx];
T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
if (static_cast<int>(label) == ignore_index) {
dx_data[idx] = static_cast<T>(0.);
} else {
T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
T diff = simoid_x - label;
T diff = simoid_x * pos_weight_idx - label;
dx_data[idx] = dout * diff;
}
}
......
......@@ -23,9 +23,11 @@
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
void SigmoidCrossEntropyWithLogitsKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
bool normalize,
int ignore_index,
DenseTensor* out) {
......@@ -33,16 +35,21 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
int limit = out->numel();
auto x_data = x.data<T>();
auto label_data = label.data<T>();
auto pos_weight_data =
(pos_weight.get_ptr() == nullptr ? nullptr
: pos_weight.get_ptr()->data<T>());
for (int idx = 0; idx < limit; ++idx) {
T x = x_data[idx];
T label = label_data[idx];
if (static_cast<int>(label) == ignore_index) {
out_data[idx] = static_cast<T>(0.);
} else {
T pos_weight_idx = pos_weight_data == nullptr ? 1 : pos_weight_data[idx];
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 = std::log(static_cast<T>(1) + std::exp(-std::abs(x)));
out_data[idx] = term1 - term2 + term3;
out_data[idx] = term1 - term2 + term3 * pos_weight_idx;
}
}
......
......@@ -52,10 +52,46 @@ struct SigmoidBwdFunctor {
}
};
template <typename T>
struct SigmoidBwdPosWeightFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);
HOSTDEVICE inline SigmoidBwdPosWeightFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}
HOSTDEVICE inline phi::Array<T, 2> operator()(const T x,
const T label,
const T pos_weight,
const T dout) {
T counts;
T dx_data;
T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
dx_data = static_cast<T>(0.);
counts = 0;
} else {
T simoid_x =
static_cast<T>(1) / (static_cast<T>(1) + phi::funcs::real_exp(-x));
T diff = simoid_x * pos_weight - label;
dx_data = dout * diff;
counts = 1;
}
phi::Array<T, 2> outs;
outs[0] = dx_data;
outs[1] = counts;
return outs;
}
};
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
void SigmoidCrossEntropyWithLogitsGradKernel(
const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &label,
const paddle::optional<DenseTensor> &pos_weight,
const DenseTensor &out_grad,
bool normalize,
int ignore_index,
......@@ -70,11 +106,19 @@ void SigmoidCrossEntropyWithLogitsGradKernel(const Context &dev_ctx,
dev_ctx.template Alloc<T>(counts_tensor);
counts_tensor->Resize(in_grad->dims());
std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
std::vector<DenseTensor *> outs = {in_grad, counts_tensor};
if (pos_weight.get_ptr() == nullptr) {
std::vector<const DenseTensor *> ins = {&x, &label, &out_grad};
auto functor = SigmoidBwdFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
} else {
std::vector<const DenseTensor *> ins = {
&x, &label, pos_weight.get_ptr(), &out_grad};
auto functor = SigmoidBwdPosWeightFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
}
if (normalize) {
DenseTensor *norm_tensor = new DenseTensor();
norm_tensor->Resize({sizeof(T)});
......
......@@ -52,10 +52,49 @@ struct SigmoidFwdFunctor {
}
};
template <typename T>
struct SigmoidFwdPosWeightFunctor {
T ignore_index_;
T eps = static_cast<T>(1e-5);
HOSTDEVICE inline SigmoidFwdPosWeightFunctor(const T ignore_index)
: ignore_index_(ignore_index) {}
HOSTDEVICE inline phi::Array<T, 2> operator()(const T x,
const T label,
T pos_weight) {
T counts;
T out_data;
T diff = label - static_cast<T>(ignore_index_);
if ((diff > -eps) && (diff < eps)) {
out_data = static_cast<T>(0.);
counts = 0;
} else {
T term1 = (x > 0) ? x : 0;
T term2 = x * label;
T term3 =
phi::funcs::real_log(static_cast<T>(1) +
phi::funcs::real_exp(static_cast<T>(-abs(x)))) *
pos_weight;
out_data = term1 - term2 + term3;
counts = 1;
}
phi::Array<T, 2> outs;
outs[0] = out_data;
outs[1] = counts;
return outs;
}
};
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
void SigmoidCrossEntropyWithLogitsKernel(
const Context &dev_ctx,
const DenseTensor &x,
const DenseTensor &label,
const paddle::optional<DenseTensor> &pos_weight,
bool normalize,
int ignore_index,
DenseTensor *out) {
......@@ -69,11 +108,19 @@ void SigmoidCrossEntropyWithLogitsKernel(const Context &dev_ctx,
dev_ctx.template Alloc<T>(counts_tensor);
counts_tensor->Resize(out->dims());
std::vector<const DenseTensor *> ins = {&x, &label};
std::vector<DenseTensor *> outs = {out, counts_tensor};
if (pos_weight.get_ptr() == nullptr) {
std::vector<const DenseTensor *> ins = {&x, &label};
auto functor = SigmoidFwdFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
} else {
std::vector<const DenseTensor *> ins = {&x, &label, pos_weight.get_ptr()};
auto functor = SigmoidFwdPosWeightFunctor<T>(ignore_index);
phi::funcs::ElementwiseKernel<T, decltype(functor), 2>(
dev_ctx, ins, &outs, functor);
}
if (normalize) {
DenseTensor *norm_tensor = new DenseTensor();
norm_tensor->Resize({sizeof(T)});
......
......@@ -19,9 +19,11 @@
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
void SigmoidCrossEntropyWithLogitsGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
......
......@@ -19,9 +19,11 @@
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
void SigmoidCrossEntropyWithLogitsKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
bool normalize,
int ignore_index,
DenseTensor* out);
......
......@@ -25,9 +25,11 @@
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsGradKernel(const Context& dev_ctx,
void SigmoidCrossEntropyWithLogitsGradKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
const DenseTensor& out_grad,
bool normalize,
int ignore_index,
......
......@@ -25,9 +25,11 @@
namespace phi {
template <typename T, typename Context>
void SigmoidCrossEntropyWithLogitsKernel(const Context& dev_ctx,
void SigmoidCrossEntropyWithLogitsKernel(
const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& label,
const paddle::optional<DenseTensor>& pos_weight,
bool normalize,
int ignore_index,
DenseTensor* out) {
......
......@@ -791,14 +791,15 @@ def binary_cross_entropy_with_logits(
logit.dtype,
_current_expected_place(),
)
out = _C_ops.sigmoid_cross_entropy_with_logits(
logit, label, False, -100
)
if pos_weight is not None:
log_weight = _C_ops.add(
pos_weight = _C_ops.add(
_C_ops.multiply(label, _C_ops.subtract(pos_weight, one)), one
)
out = _C_ops.multiply(out, log_weight)
out = _C_ops.sigmoid_cross_entropy_with_logits(
logit, label, pos_weight, False, -100
)
if weight is not None:
out = _C_ops.multiply(out, weight)
......@@ -829,13 +830,6 @@ def binary_cross_entropy_with_logits(
out = helper.create_variable_for_type_inference(dtype=logit.dtype)
helper.append_op(
type="sigmoid_cross_entropy_with_logits",
inputs={"X": logit, "Label": label},
attrs={"ignore_index": kIgnoreIndex, 'normalize': False},
outputs={"Out": out},
)
one = paddle.full(shape=[1], fill_value=1.0, dtype=logit.dtype)
if pos_weight is not None:
check_variable_and_dtype(
......@@ -844,13 +838,16 @@ def binary_cross_entropy_with_logits(
['float32', 'float64'],
'binary_cross_entropy_with_logits',
)
log_weight = paddle.add(
pos_weight = paddle.add(
paddle.multiply(label, paddle.subtract(pos_weight, one)), one
)
pos_weight_name = (
name if reduction == 'none' and weight is None else None
helper.append_op(
type="sigmoid_cross_entropy_with_logits",
inputs={"X": logit, "Label": label, "pos_weight": pos_weight},
attrs={"ignore_index": kIgnoreIndex, 'normalize': False},
outputs={"Out": out},
)
out = paddle.multiply(out, log_weight, name=pos_weight_name)
if weight is not None:
check_variable_and_dtype(
......@@ -3061,7 +3058,7 @@ def sigmoid_focal_loss(
one = _C_ops.full(logit.shape, float(1.0), logit.dtype, place)
loss = _C_ops.sigmoid_cross_entropy_with_logits(
logit, label, False, -100
logit, label, None, False, -100
)
pred = _C_ops.sigmoid(logit)
......@@ -3108,7 +3105,7 @@ def sigmoid_focal_loss(
if reduction == 'none' and normalizer is None:
bce_name = name
loss = paddle.nn.functional.binary_cross_entropy_with_logits(
logit, label, reduction='none', name=bce_name
logit, label, None, reduction='none', name=bce_name
)
pred = paddle.nn.functional.sigmoid(logit)
......
......@@ -114,13 +114,16 @@ def test_dygraph(
def calc_bce_with_logits_loss(
logit_np, label_np, reduction='mean', weight_np=None, pos_weight=None
):
expected = (
np.maximum(logit_np, 0)
- logit_np * label_np
+ np.log(1 + np.exp(-np.abs(logit_np)))
)
item1 = np.maximum(logit_np, 0)
item2 = logit_np * label_np
item3 = np.log(1 + np.exp(-np.abs(logit_np)))
if pos_weight is not None:
expected = expected * ((pos_weight - 1) * label_np + 1)
pos_weight = (pos_weight - 1) * label_np + 1
expected = item1 - item2 + item3 * pos_weight
else:
expected = item1 - item2 + item3
if weight_np is not None:
expected = weight_np * expected
......
......@@ -23,9 +23,11 @@ from paddle import fluid
from paddle.fluid import Program, program_guard
def loss_wrapper(logit, label, normalize=False, ignore_index=-100):
def loss_wrapper(
logit, label, pos_weight=None, normalize=False, ignore_index=-100
):
out = paddle._C_ops.sigmoid_cross_entropy_with_logits(
logit, label, normalize, ignore_index
logit, label, pos_weight, normalize, ignore_index
)
return out
......@@ -137,6 +139,44 @@ class TestSigmoidCrossEntropyWithLogitsOp3(OpTest):
self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithLogitsOp4(OpTest):
"""Test sigmoid_cross_entropy_with_logit_op with probabalistic label"""
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
self.python_api = loss_wrapper
batch_size = 64
num_classes = 20
x = logit(
np.random.uniform(0, 1, (batch_size, num_classes)).astype("float64")
)
label = np.random.uniform(0, 1, (batch_size, num_classes)).astype(
"float64"
)
pos_weight = np.random.uniform(0, 1, (batch_size, num_classes)).astype(
"float64"
)
self.inputs = {
'X': x,
'Label': label,
'pos_weight': pos_weight,
}
# Fw Pass is implemented as elementwise sigmoid followed by
# elementwise logistic loss
term1 = np.maximum(self.inputs['X'], 0)
term2 = self.inputs['X'] * self.inputs['Label']
term3 = np.log(1 + np.exp(-1 * np.abs(self.inputs['X']))) * pos_weight
self.outputs = {'Out': term1 - term2 + term3}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithNorm(OpTest):
def setUp(self):
self.op_type = "sigmoid_cross_entropy_with_logits"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册