Unverified commit bf4a4636, authored by Zhong Hui, committed by GitHub

change to use bce_loss op, add shape check for bce_loss

change to use bce_loss op, add numel check for bce_loss.
Parent 0e816260
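Background for the change below: binary_cross_entropy previously built the loss in Python from fill_constant, log, and elementwise ops; this commit routes both the dygraph and static-graph paths through the single bce_loss op and tightens the op's shape checks. For reference, here is the elementwise formula the op computes (quoted in the CPU kernel comment further down) as a minimal NumPy sketch; it is illustrative only and not part of the commit:

    import numpy as np

    def bce_loss_ref(x, label):
        # out = -(label * ln(x) + (1 - label) * ln(1 - x))
        #     = (label - 1) * ln(1 - x) - label * ln(x)
        # x must lie strictly in (0, 1) so both logarithms are finite
        return (label - 1.0) * np.log(1.0 - x) - label * np.log(x)

    x = np.random.uniform(0.1, 0.8, size=(2, 3))
    label = np.random.randint(0, 2, size=(2, 3)).astype(np.float64)
    print(bce_loss_ref(x, label))  # elementwise loss, shape (2, 3)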
......@@ -32,22 +32,29 @@ class BCELossOp : public framework::OperatorWithKernel {
     OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "BCELoss");
     auto x_dims = ctx->GetInputDim("X");
-    auto label_dims = ctx->GetInputDim("Label");
-    PADDLE_ENFORCE_EQ(
-        x_dims.size(), label_dims.size(),
-        platform::errors::InvalidArgument(
-            "Input(X) and Input(Label) shall have the same shape."));
-    bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
-                               framework::contain_unknown_dim(label_dims);
-    bool check = ctx->IsRuntime() || !contain_unknown_dim;
+    auto labels_dims = ctx->GetInputDim("Label");
+    int rank = x_dims.size();
+    PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
+                      platform::errors::InvalidArgument(
+                          "Input(X) and Input(Label) shall have the same rank."
+                          "But received: the rank of Input(X) is [%d], "
+                          "the rank of Input(Label) is [%d].",
+                          rank, labels_dims.size()));
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                framework::product(labels_dims) <= 0)) {
+      check = false;
+    }
     if (check) {
-      PADDLE_ENFORCE_EQ(
-          x_dims.size(), label_dims.size(),
-          platform::errors::InvalidArgument(
-              "ShapeError: Input(X) and Input(Label) shall have the same shape "
-              "But received: the shape of Input(X) is [%s], the shape of "
-              "Input(Label) is [%s].",
-              x_dims, label_dims));
+      PADDLE_ENFORCE_EQ(x_dims, labels_dims,
+                        platform::errors::InvalidArgument(
+                            "Input(X) and Input(Label) shall have the same "
+                            "shape. But received: the shape of Input(X) is "
+                            "[%s], the shape of Input(Label) is [%s].",
+                            x_dims, labels_dims));
     }
     ctx->ShareDim("X", "Out");
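A consequence of the stricter InferShape above: Input(X) and Input(Label) must now match exactly, so the broadcast case exercised by the test_BCELoss_boardcast test deleted further down is rejected at op level. A hedged sketch of what a caller would now see; the exact Python-side exception class surfaced from PADDLE_ENFORCE_EQ is an assumption here:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(
        np.random.uniform(0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64))
    label = paddle.to_tensor(
        np.random.randint(0, 2, size=(3, 4, 10)).astype(np.float64))
    try:
        F.binary_cross_entropy(x, label)  # rank 4 vs rank 3: the rank check fires
    except Exception as e:  # assumed surface exception type; not verified
        print(type(e).__name__)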
......@@ -76,20 +83,31 @@ class BCELossGradOp : public framework::OperatorWithKernel {
                    framework::GradVarName("X"), "BCELossGrad");
     auto x_dims = ctx->GetInputDim("X");
+    auto labels_dims = ctx->GetInputDim("Label");
     auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
-    bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
-                               framework::contain_unknown_dim(dout_dims);
-    bool check = ctx->IsRuntime() || !contain_unknown_dim;
+    bool check = true;
+    if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
+                                framework::product(labels_dims) <= 0)) {
+      check = false;
+    }
     if (check) {
+      PADDLE_ENFORCE_EQ(x_dims, labels_dims,
+                        platform::errors::InvalidArgument(
+                            "Input(X) and Input(Label) shall have the same "
+                            "shape. But received: the shape of Input(X) is "
+                            "[%s], the shape of Input(Label) is [%s].",
+                            x_dims, labels_dims));
       PADDLE_ENFORCE_EQ(x_dims, dout_dims,
                         platform::errors::InvalidArgument(
-                            "ShapeError:The Input(X) and Input(Out@Grad) "
-                            "should have the same "
-                            "shape, But received: the shape of Input(X) is "
-                            "[%s], the shape of "
-                            "Input(Out@GRAD) is [%s].",
+                            "Input(X) and Input(Out@Grad) shall have the same "
+                            "shape. But received: the shape of Input(X) is "
+                            "[%s], the shape of Input(Out@Grad) is [%s].",
                             x_dims, dout_dims));
     }
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
     ctx->ShareLoD("X", framework::GradVarName("X"));
   }
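For context on why the grad op now checks Label as well: the backward pass consumes X and Label elementwise, since the standard BCE derivative (not shown in this hunk) is d out / d x = -label/x + (1 - label)/(1 - x) = (x - label) / (x * (1 - x)), so Grad(X) can only be formed when X, Label, and Out@Grad share one shape.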
......
......@@ -67,7 +67,8 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
     auto x_data = x->data<T>();
     auto out_data = out->mutable_data<T>(ctx.GetPlace());
-    int x_numel = x->numel();
+    auto x_numel = x->numel();
     platform::GpuLaunchConfig config =
         platform::getGpuLaunchConfig(x_numel, ctx);
......@@ -75,7 +76,7 @@ class BCELossCUDAKernel : public framework::OpKernel<T> {
     framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu);
     T* x_cpu_data = x_cpu.data<T>();
-    for (int i = 0; i < x_numel; ++i) {
+    for (int64_t i = 0; i < x_numel; ++i) {
       PADDLE_ENFORCE_GE(
           x_cpu_data[i], static_cast<T>(0),
           platform::errors::InvalidArgument(
......
......@@ -34,11 +34,11 @@ class BCELossOpKernel : public framework::OpKernel<T> {
     auto x_data = x->data<T>();
     auto label_data = labels->data<T>();
     auto out_data = out->mutable_data<T>(ctx.GetPlace());
-    int x_numel = x->numel();
+    auto x_numel = x->numel();
     // out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 -
     // x) - label * ln(x)
-    for (int i = 0; i < x_numel; ++i) {
+    for (int64_t i = 0; i < x_numel; ++i) {
       PADDLE_ENFORCE_GE(
           x_data[i], static_cast<T>(0),
           platform::errors::InvalidArgument(
......
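Why both kernels switch the loop index away from int: numel() returns a 64-bit element count (so auto binds to int64_t), and a 32-bit int index would overflow once a tensor holds more than 2^31 - 1 elements. A quick arithmetic check, plain Python and illustrative only:

    INT_MAX = 2**31 - 1            # 2147483647
    numel = 4000 * 600 * 1000      # 2.4e9 elements in one tensor
    print(numel > INT_MAX)         # True: an int loop index would overflow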
......@@ -189,20 +189,6 @@ class TestBCELoss(unittest.TestCase):
         self.assertTrue(np.allclose(static_functional, dy_functional))
         self.assertTrue(np.allclose(dy_functional, expected))
-    def test_BCELoss_boardcast(self):
-        input_np = np.random.uniform(
-            0.1, 0.8, size=(2, 3, 4, 10)).astype(np.float64)
-        label_np = np.random.randint(0, 2, size=(3, 4, 10)).astype(np.float64)
-        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
-        ) else fluid.CPUPlace()
-        static_result = test_static_layer(place, input_np, label_np)
-        dy_result = test_dygraph_layer(place, input_np, label_np)
-        expected = calc_bceloss(input_np, label_np)
-        self.assertTrue(np.allclose(static_result, expected))
-        self.assertTrue(np.allclose(static_result, dy_result))
-        self.assertTrue(np.allclose(dy_result, expected))
     def test_BCELoss_error(self):
         paddle.disable_static()
         self.assertRaises(
......
......@@ -157,19 +157,7 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean',
                          reduction)
     if in_dygraph_mode():
-        one = _varbase_creator(dtype=input.dtype)
-        core.ops.fill_constant(one, 'value',
-                               float(1.0), 'force_cpu', False, 'dtype',
-                               one.dtype, 'str_value', '1.0', 'shape', [1])
-        one.stop_gradient = True
-        label_minus = core.ops.elementwise_sub(label, one)
-        input_minus = core.ops.elementwise_sub(one, input)
-        input_minus_log = core.ops.log(input_minus)
-        input_log = core.ops.log(input)
-        loss_1 = core.ops.elementwise_mul(label_minus, input_minus_log)
-        loss_2 = core.ops.elementwise_mul(label, input_log)
-        out = core.ops.elementwise_sub(loss_1, loss_2)
+        out = core.ops.bce_loss(input, label)
         if weight is not None:
             out = core.ops.elementwise_mul(out, weight, 'axis', -1)
......@@ -187,17 +175,16 @@ def binary_cross_entropy(input, label, weight=None, reduction='mean',
     fluid.data_feeder.check_variable_and_dtype(
         label, 'label', ['float32', 'float64'], 'binary_cross_entropy')
-    one = paddle.fill_constant(shape=[1], value=1.0, dtype=input.dtype)
-    one.stop_gradient = True
-    label_minus = paddle.elementwise_sub(label, one)
-    input_minus = paddle.elementwise_sub(one, input)
-    input_minus_log = paddle.log(input_minus)
-    input_log = paddle.log(input)
-    loss_1 = paddle.multiply(label_minus, input_minus_log)
-    loss_2 = paddle.multiply(label, input_log)
     sub_name = name if weight is None and reduction is 'none' else None
-    out = paddle.elementwise_sub(loss_1, loss_2, name=sub_name)
+    helper = LayerHelper("binary_cross_entropy", name=sub_name)
+    out = helper.create_variable_for_type_inference(dtype=input.dtype)
+    helper.append_op(
+        type='bce_loss',
+        inputs={
+            'X': [input],
+            'Label': [label],
+        },
+        outputs={'Out': [out]})
     if weight is not None:
         if isinstance(weight, paddle.framework.Variable):
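With the rewrite above, the functional API in dygraph mode reduces to the single bce_loss op plus the optional weight and reduction handling. A minimal usage sketch against the post-commit paddle.nn.functional API; the printed value depends on the random inputs:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    paddle.disable_static()
    x = paddle.to_tensor(np.random.uniform(0.1, 0.8, size=(3, 2)))
    label = paddle.to_tensor(
        np.random.randint(0, 2, size=(3, 2)).astype(np.float64))
    loss = F.binary_cross_entropy(x, label, reduction='mean')
    print(loss.numpy())  # a single averaged scalar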
......@@ -952,9 +939,9 @@ def ctc_loss(log_probs,
                  reduction='mean'):
     """
-    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
-    to compute Connectionist Temporal Classification (CTC) loss.
-    It can be aliased as softmax with CTC, since a native softmax activation
+    An operator integrating the open source Warp-CTC library (https://github.com/baidu-research/warp-ctc)
+    to compute Connectionist Temporal Classification (CTC) loss.
+    It can be aliased as softmax with CTC, since a native softmax activation
     is interated to the Warp-CTC library to normalize values for each row of the input tensor.
     Parameters:
......@@ -967,7 +954,7 @@ def ctc_loss(log_probs,
     Returns:
         Tensor, The Connectionist Temporal Classification (CTC) loss between ``log_probs`` and ``labels``. If attr:`reduction` is ``'none'``, the shape of loss is [batch_size], otherwise, the shape of loss is [1]. Data type is the same as ``log_probs``.
     Examples:
         .. code-block:: python
......@@ -1012,18 +999,18 @@ def ctc_loss(log_probs,
             input_lengths = paddle.to_tensor(input_lengths)
             label_lengths = paddle.to_tensor(label_lengths)
-            loss = F.ctc_loss(log_probs, labels,
-                              input_lengths,
-                              label_lengths,
-                              blank=0,
+            loss = F.ctc_loss(log_probs, labels,
+                              input_lengths,
+                              label_lengths,
+                              blank=0,
                               reduction='none')
             print(loss.numpy()) #[3.9179852 2.9076521]
-            loss = F.ctc_loss(log_probs, labels,
-                              input_lengths,
-                              label_lengths,
-                              blank=0,
-                              reduction='mean')
+            loss = F.ctc_loss(log_probs, labels,
+                              input_lengths,
+                              label_lengths,
+                              blank=0,
+                              reduction='mean')
             print(loss.numpy()) #[1.1376063]
     """
......@@ -1071,8 +1058,8 @@ def cross_entropy(input,
     Parameters:
         input (Tensor): Input tensor, the data type is float32, float64. Shape is
             (N, C), where C is number of classes, and if shape is more than 2D, this
-            is (N, C, D1, D2,..., Dk), k >= 1.
-        label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
+            is (N, C, D1, D2,..., Dk), k >= 1.
+        label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
             value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
             (N, D1, D2,..., Dk), k >= 1.
         weight (Tensor, optional): Weight tensor, a manual rescaling weight given
......@@ -1105,7 +1092,7 @@ def cross_entropy(input,
             weight = paddle.to_tensor(weight_data)
             loss = paddle.nn.functional.cross_entropy(input=input, label=label, weight=weight)
             print(loss.numpy())
     """
     if not in_dygraph_mode():
         fluid.data_feeder.check_variable_and_dtype(
......@@ -1124,7 +1111,7 @@ def cross_entropy(input,
         raise ValueError(
             "The weight' is not a Variable, please convert to Variable.")
-    #step 2. nll_loss
+    #step 2. nll_loss
     input = log_softmax_out
     helper = LayerHelper('nll_loss', **locals())
     dtype = helper.input_dtype(input)
......