Unverified commit ab05cdc4 · Author: ceci3 · Committer: GitHub

Add bce_loss op (#23388)

* add bce_loss

* fix mistake

* replace paddle_enforce,test=develop

* fix,test=develop

* update,test=develop

* remove duplicate,test=develop

* update,test=develop

* update error,test=develop

* update,test=develop

* fix unittest, test=develop

* update, test=develop
Parent faf284a9
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/bce_loss_op.h"
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class BCELossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::InvalidArgument("Output(Out) should be not null."));
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(
x_dims.size(), label_dims.size(),
platform::errors::InvalidArgument(
"Input(X) and Input(Label) shall have the same shape."));
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(label_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(
x_dims.size(), label_dims.size(),
platform::errors::InvalidArgument(
"ShapeError: Input(X) and Input(Label) shall have the same shape "
"But received: the shape of Input(X) is [%s], the shape of "
"Input(Label) is [%s].",
x_dims, label_dims));
}
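// The loss is computed element-wise, so Out shares X's shape and LoD.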
ctx->ShareDim("X", "Out");
ctx->ShareLoD("X", "Out");
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class BCELossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::InvalidArgument("Input(X) should be not null."));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Label"), true,
platform::errors::InvalidArgument("Input(Label) should be not null."));
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::InvalidArgument(
"Input(Out@GRAD) shoudl be not null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
platform::errors::InvalidArgument(
"Output(X@GRAD) should be not null."));
auto x_dims = ctx->GetInputDim("X");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(dout_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(x_dims, dout_dims,
platform::errors::InvalidArgument(
"ShapeError:The Input(X) and Input(Out@Grad) "
"should have the same "
"shape, But received: the shape of Input(X) is "
"[%s], the shape of "
"Input(Out@GRAD) is [%s].",
x_dims, dout_dims));
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD("X", framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class BCELossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor, default Tensor<float>), the input is a tensor of logits"
"computed by the previous operator, which is always the result of"
"a sigmoid operator. Input must between in 0 and 1.");
AddInput("Label",
"(Tensor, default Tensor<float>), have same shape with input"
"label should between in 0 and 1.");
AddOutput("Out",
"(Tensor, default Tensor<float>), have same shape with"
"input");
AddComment(R"DOC(
BinaryCrossEntropy operator.
This measures the element-wise probability error in classification tasks
in which each class is independent.
The logistic loss is given as follows:
$$loss = -Label * \log(X) - (1 - Label) * \log(1 - X)$$
)DOC");
}
};
template <typename T>
class BCELossGradOpMaker : public framework::SingleGradOpMaker<T> {
public:
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("bce_loss_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Label", this->Input("Label"));
op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
// op->SetAttrMap(this->Attrs());
}
};
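// The forward op may reuse X's buffer for Out, and the backward op may reuse
// Out@GRAD's buffer for X@GRAD, since each output is written element-wise.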
DECLARE_INPLACE_OP_INFERER(BCELossInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(BCELossGradInplaceInferer,
{framework::GradVarName("Out"),
framework::GradVarName("X")});
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(bce_loss, ops::BCELossOp, ops::BCELossOpMaker,
ops::BCELossGradOpMaker<paddle::framework::OpDesc>,
ops::BCELossGradOpMaker<paddle::imperative::OpBase>,
ops::BCELossInplaceInferer);
REGISTER_OPERATOR(bce_loss_grad, ops::BCELossGradOp,
ops::BCELossGradInplaceInferer);
REGISTER_OP_CPU_KERNEL(
bce_loss, ops::BCELossOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BCELossOpKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
bce_loss_grad,
ops::BCELossGradOpKernel<paddle::platform::CPUDeviceContext, float>,
ops::BCELossGradOpKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "cub/cub.cuh"
#include "paddle/fluid/operators/bce_loss_op.h"
#include "paddle/fluid/operators/math.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_launch_config.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
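// Grid-stride loop: each thread handles indices i, i + blockDim.x * gridDim.x,
// ... so any launch configuration covers all n elements.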
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
template <typename T>
__global__ void GPUBCELossForward(const T* x_data, const T* label_data,
T* out_data, const int in_numel) {
CUDA_1D_KERNEL_LOOP(i, in_numel) {
T x = x_data[i];
T label = label_data[i];
T one = static_cast<T>(1.);
T neg_100 = static_cast<T>(-100.);
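// Clamp each log term at -100 so x == 0 or x == 1 yields a large finite loss
// instead of -inf (same clipping as the CPU kernel).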
T term1 = max(real_log(x), neg_100);
T term2 = max(real_log(one - x), neg_100);
out_data[i] = ((label - one) * term2) - (label * term1);
}
}
template <typename T>
__global__ void GPUBCELossBackward(const T* x_data, const T* label_data,
const T* dout_data, T* dx_data,
const int in_numel) {
CUDA_1D_KERNEL_LOOP(i, in_numel) {
T x = x_data[i];
T label = label_data[i];
T dout = dout_data[i];
T one = static_cast<T>(1.);
T eps = static_cast<T>(1e-12);
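// Lower-bound the denominator (1 - x) * x by eps to avoid dividing by zero
// when x is exactly 0 or 1.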
T term1 = max((one - x) * x, eps);
dx_data[i] = dout * (x - label) / term1;
}
}
template <typename DeviceContext, typename T>
class BCELossCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* out = ctx.Output<Tensor>("Out");
auto x_data = x->data<T>();
auto out_data = out->mutable_data<T>(ctx.GetPlace());
int x_numel = x->numel();
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(x_numel, ctx);
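// Copy X back to the host so every element can be checked to lie in [0, 1]
// before the kernel is launched.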
Tensor x_cpu;
framework::TensorCopy(*x, platform::CPUPlace(), &x_cpu);
T* x_cpu_data = x_cpu.data<T>();
for (int i = 0; i < x_numel; ++i) {
PADDLE_ENFORCE_GE(
x_cpu_data[i], static_cast<T>(0),
platform::errors::InvalidArgument(
"Illegal input, input must be greater than or equal to 0"));
PADDLE_ENFORCE_LE(
x_cpu_data[i], static_cast<T>(1),
platform::errors::InvalidArgument(
"Illegal input, input must be less than or equal to 1"));
}
auto& dev_ctx = ctx.cuda_device_context();
GPUBCELossForward<
T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
x_data, labels->data<T>(), out_data, x_numel);
}
};
template <typename DeviceContext, typename T>
class BCELossGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
int x_numel = x->numel();
platform::GpuLaunchConfig config =
platform::getGpuLaunchConfig(x_numel, ctx);
auto& dev_ctx = ctx.cuda_device_context();
GPUBCELossBackward<
T><<<config.blocks, config.threads, 0, dev_ctx.stream()>>>(
x->data<T>(), labels->data<T>(), dout->data<T>(), dx_data, x_numel);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
bce_loss,
ops::BCELossCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BCELossCUDAKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
bce_loss_grad,
ops::BCELossGradCUDAKernel<paddle::platform::CUDADeviceContext, float>,
ops::BCELossGradCUDAKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <algorithm> // for max
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class BCELossOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* out = ctx.Output<Tensor>("Out");
auto x_data = x->data<T>();
auto label_data = labels->data<T>();
auto out_data = out->mutable_data<T>(ctx.GetPlace());
int x_numel = x->numel();
// out = -(label * ln(x) + (1 - label) * ln(1 - x)) = (label - 1) * ln(1 -
// x) - label * ln(x)
for (int i = 0; i < x_numel; ++i) {
PADDLE_ENFORCE_GE(
x_data[i], static_cast<T>(0),
platform::errors::InvalidArgument(
"Illegal input, input must be greater than or equal to 0"));
PADDLE_ENFORCE_LE(
x_data[i], static_cast<T>(1),
platform::errors::InvalidArgument(
"Illegal input, input must be less than or equal to 1"));
out_data[i] =
(label_data[i] - static_cast<T>(1)) *
std::max(real_log(static_cast<T>(1) - x_data[i]), (T)(-100)) -
label_data[i] * std::max(real_log(x_data[i]), (T)(-100));
}
}
};
template <typename DeviceContext, typename T>
class BCELossGradOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* labels = ctx.Input<Tensor>("Label");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto dx_data = dx->mutable_data<T>(ctx.GetPlace());
auto dout_data = dout->data<T>();
auto x_data = x->data<T>();
auto label_data = labels->data<T>();
int x_numel = x->numel();
// dx = dout * ((x - label)/(x - x^2))
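// The denominator is clamped at 1e-12 to avoid division by zero when x is
// exactly 0 or 1.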
for (int i = 0; i < x_numel; ++i) {
dx_data[i] =
dout_data[i] * ((x_data[i] - label_data[i]) /
std::max((static_cast<T>(1) - x_data[i]) * x_data[i],
static_cast<T>(1e-12)));
}
}
};
} // namespace operators
} // namespace paddle
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid as fluid
import numpy as np
import unittest
from op_test import OpTest
class TestBCELoss(unittest.TestCase):
    def test_BCELoss(self):
        input_np = np.random.random(size=(20, 30)).astype(np.float64)
        label_np = np.random.random(size=(20, 30)).astype(np.float64)
        prog = fluid.Program()
        startup_prog = fluid.Program()
        places = [fluid.CPUPlace()]
        if fluid.core.is_compiled_with_cuda():
            places.append(fluid.CUDAPlace(0))
        reductions = ['sum', 'mean', 'none']
        for place in places:
            for red in reductions:
                with fluid.program_guard(prog, startup_prog):
                    input = fluid.data(
                        name='input', shape=[None, 30], dtype='float64')
                    label = fluid.data(
                        name='label', shape=[None, 30], dtype='float64')
                    bce_loss = paddle.nn.loss.BCELoss(reduction=red)
                    res = bce_loss(input, label)

                    exe = fluid.Executor(place)
                    static_result = exe.run(
                        prog,
                        feed={"input": input_np,
                              "label": label_np},
                        fetch_list=[res])

                with fluid.dygraph.guard():
                    bce_loss = paddle.nn.loss.BCELoss(reduction=red)
                    dy_res = bce_loss(
                        fluid.dygraph.to_variable(input_np),
                        fluid.dygraph.to_variable(label_np))
                    dy_result = dy_res.numpy()

                expected = -1 * (label_np * np.log(input_np) +
                                 (1. - label_np) * np.log(1. - input_np))
                if red == 'mean':
                    expected = np.mean(expected)
                elif red == 'sum':
                    expected = np.sum(expected)
                else:
                    expected = expected
                self.assertTrue(np.allclose(static_result, expected))
                self.assertTrue(np.allclose(static_result, dy_result))
                self.assertTrue(np.allclose(dy_result, expected))

    def test_BCELoss_weight(self):
        input_np = np.random.random(size=(20, 30)).astype(np.float64)
        label_np = np.random.random(size=(20, 30)).astype(np.float64)
        weight_np = np.random.random(size=(20, 30)).astype(np.float64)
        prog = fluid.Program()
        startup_prog = fluid.Program()
        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
        ) else fluid.CPUPlace()
        with fluid.program_guard(prog, startup_prog):
            input = fluid.data(name='input', shape=[None, 30], dtype='float64')
            label = fluid.data(name='label', shape=[None, 30], dtype='float64')
            weight = fluid.data(
                name='weight', shape=[None, 30], dtype='float64')
            bce_loss = paddle.nn.loss.BCELoss(weight=weight)
            res = bce_loss(input, label)

            exe = fluid.Executor(place)
            static_result = exe.run(prog,
                                    feed={
                                        "input": input_np,
                                        "label": label_np,
                                        "weight": weight_np
                                    },
                                    fetch_list=[res])

        with fluid.dygraph.guard():
            bce_loss = paddle.nn.loss.BCELoss(
                weight=fluid.dygraph.to_variable(weight_np))
            dy_res = bce_loss(
                fluid.dygraph.to_variable(input_np),
                fluid.dygraph.to_variable(label_np))
            dy_result = dy_res.numpy()

        expected = np.mean(-1 * weight_np *
                           (label_np * np.log(input_np) +
                            (1. - label_np) * np.log(1. - input_np)))
        self.assertTrue(np.allclose(static_result, expected))
        self.assertTrue(np.allclose(static_result, dy_result))
        self.assertTrue(np.allclose(dy_result, expected))

def bce_loss(input, label):
    return -1 * (label * np.log(input) + (1. - label) * np.log(1. - input))


class TestBceLossOp(OpTest):
    def setUp(self):
        self.init_test_case()
        self.op_type = "bce_loss"
        input_np = np.random.uniform(0.1, 0.8, self.shape).astype("float64")
        label_np = np.random.randint(0, 2, self.shape).astype("float64")
        output_np = bce_loss(input_np, label_np)

        self.inputs = {'X': input_np, 'Label': label_np}
        self.outputs = {'Out': output_np}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def init_test_case(self):
        self.shape = [10, 10]


if __name__ == "__main__":
    unittest.main()
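For reference, the element-wise formulas implemented by the CPU and CUDA kernels above can be sanity-checked in a few lines of NumPy. This is a minimal sketch under the same clamping assumptions as the kernels (log terms clipped at -100, the gradient denominator clipped at 1e-12); the helper names bce_forward and bce_backward are illustrative only and not part of the commit.

import numpy as np

def bce_forward(x, label):
    # Mirrors BCELossOpKernel: clamp each log term at -100.
    term1 = np.maximum(np.log(x), -100.0)
    term2 = np.maximum(np.log(1.0 - x), -100.0)
    return (label - 1.0) * term2 - label * term1

def bce_backward(x, label, dout):
    # Mirrors BCELossGradOpKernel: clamp the denominator at 1e-12.
    denom = np.maximum((1.0 - x) * x, 1e-12)
    return dout * (x - label) / denom

x = np.array([0.5, 0.6, 0.7])
label = np.array([1.0, 0.0, 1.0])
print(bce_forward(x, label))                     # per-element loss
print(bce_backward(x, label, np.ones_like(x)))   # d(loss)/dx for dout = 1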
@@ -66,7 +66,7 @@ from .layer import loss #DEFINE_ALIAS
from .layer import conv #DEFINE_ALIAS
from .layer.conv import Conv2D, Conv2DTranspose, Conv3D, Conv3DTranspose #DEFINE_ALIAS
# from .layer.loss import NLLLoss #DEFINE_ALIAS
# from .layer.loss import BCELoss #DEFINE_ALIAS
from .layer.loss import BCELoss #DEFINE_ALIAS
# from .layer.learning_rate import CosineDecay #DEFINE_ALIAS
# from .layer.learning_rate import ExponentialDecay #DEFINE_ALIAS
# from .layer.learning_rate import InverseTimeDecay #DEFINE_ALIAS
@@ -20,7 +20,7 @@ __all__ = [
    # 'MSELoss',
    'L1Loss',
    # 'NLLLoss',
    # 'BCELoss'
    'BCELoss'
]
@@ -109,3 +109,112 @@ class L1Loss(fluid.dygraph.Layer):
            return fluid.layers.reduce_mean(unreduced)
        else:
            return unreduced

class BCELoss(fluid.dygraph.Layer):
    """
    This op accepts input predictions and target labels and returns the
    binary cross entropy loss.

    For each prediction and its target label, the loss is calculated as
    follows.

    If :attr:`weight` is set, the loss is:

        Out = -1 * weight * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`weight` is None, the loss is:

        Out = -1 * (label * log(input) + (1 - label) * log(1 - input))

    If :attr:`reduction` is set to ``'none'``, the unreduced loss is:

    .. math::
        Out = Out

    If :attr:`reduction` is set to ``'mean'``, the reduced mean loss is:

    .. math::
        Out = MEAN(Out)

    If :attr:`reduction` is set to ``'sum'``, the reduced sum loss is:

    .. math::
        Out = SUM(Out)

    Parameters:
        input (Variable): Input tensor, the data type is float32 or float64.
            Input values must be in (0, 1).
        label (Variable): Label tensor, with the same shape as input; the
            data type is float32 or float64.
        weight (Variable, optional): A manual rescaling weight applied to the
            loss of each batch element; it should be broadcastable to the
            shape of the loss. The data type is float32, float64, int32 or
            int64. Default is ``None``.
        reduction (str, optional): Indicates how to reduce the loss; the
            candidates are ``'none'`` | ``'mean'`` | ``'sum'``.
            If :attr:`reduction` is ``'mean'``, the mean of the loss is
            returned; if it is ``'sum'``, the summed loss is returned;
            if it is ``'none'``, the unreduced element-wise loss is returned.
            Default is ``'mean'``.

    Returns:
        The tensor variable storing the bce_loss of input and label.

    Return type: Variable.

    Examples:
        .. code-block:: python

            # declarative mode
            import paddle.fluid as fluid
            import numpy as np
            import paddle

            input = fluid.data(name="input", shape=[3, 1], dtype='float32')
            label = fluid.data(name="label", shape=[3, 1], dtype='float32')
            bce_loss = paddle.nn.loss.BCELoss()
            output = bce_loss(input, label)
            place = fluid.CPUPlace()
            exe = fluid.Executor(place)
            exe.run(fluid.default_startup_program())

            input_data = np.array([[0.5], [0.6], [0.7]]).astype("float32")
            label_data = np.array([[1.0], [0.0], [1.0]]).astype("float32")
            output_data = exe.run(fluid.default_main_program(),
                                  feed={"input": input_data, "label": label_data},
                                  fetch_list=[output],
                                  return_numpy=True)
            print(output_data)  # [array([0.65537095], dtype=float32)]

            # imperative mode
            import paddle.fluid.dygraph as dg
            with dg.guard(place) as g:
                input = dg.to_variable(input_data)
                label = dg.to_variable(label_data)
                output = bce_loss(input, label)
                print(output.numpy())  # [0.65537095]
    """
    def __init__(self, weight=None, reduction='mean'):
        if reduction not in ['sum', 'mean', 'none']:
            raise ValueError(
                "The value of 'reduction' in bce_loss should be 'sum', 'mean' or 'none', but "
                "received %s, which is not allowed." % reduction)

        super(BCELoss, self).__init__()
        self.weight = weight
        self.reduction = reduction

    def forward(self, input, label):
        dtype = self._helper.input_dtype(input)

        fluid.data_feeder.check_variable_and_dtype(
            input, 'input', ['float32', 'float64'], 'bce_loss')
        fluid.data_feeder.check_variable_and_dtype(
            label, 'label', ['float32', 'float64'], 'bce_loss')

        out = self._helper.create_variable_for_type_inference(dtype=input.dtype)
        self._helper.append_op(
            type='bce_loss',
            inputs={
                'X': [input],
                'Label': [label],
            },
            outputs={'Out': [out]})
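        # The bce_loss op produces the raw element-wise loss; the optional
        # weight and the requested reduction are applied on top below.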
        if self.weight is not None:
            if isinstance(self.weight, fluid.framework.Variable):
                w = self.weight
                out = fluid.layers.elementwise_mul(out, w, axis=0)
            else:
                raise ValueError(
                    "The weight is not a Variable, please convert to Variable.")

        if self.reduction == 'sum':
            return fluid.layers.reduce_sum(out)
        elif self.reduction == 'mean':
            return fluid.layers.reduce_mean(out)
        else:
            return out