未验证 提交 3d939d32 编写于 作者: K Kaipeng Deng 提交者: GitHub

Merge pull request #16023 from heavengate/kl_div_loss

KL div loss: add kldiv_loss op
......@@ -230,6 +230,7 @@ paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func',
paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '1546136806fef5c08f6918544bd9151d'))
paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', '2f6ff96864054a31aa4bb659c6722c99'))
paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '431a4301c35032166ec029f7432c80a7'))
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '776d536cac47c89073abc7ee524d5aec'))
paddle.fluid.layers.tree_conv (ArgSpec(args=['nodes_vector', 'edge_set', 'output_size', 'num_filters', 'max_depth', 'act', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 2, 'tanh', None, None, None)), ('document', '34ea12ac9f10a65dccbc50100d12e607'))
paddle.fluid.layers.npair_loss (ArgSpec(args=['anchor', 'positive', 'labels', 'l2_reg'], varargs=None, keywords=None, defaults=(0.002,)), ('document', '46994d10276dd4cb803b4062b5d14329'))
paddle.fluid.layers.fsp_matrix (ArgSpec(args=['x', 'y'], varargs=None, keywords=None, defaults=None), ('document', 'b76ccca3735bea4a58a0dbf0d77c5393'))
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kldiv_loss_op.h"
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
using framework::Tensor;
class KLDivLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of KLDivLossOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Target"),
"Input(Target) of KLDivLossOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Loss"),
"Output(Loss) of KLDivLossOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
auto dim_target = ctx->GetInputDim("Target");
PADDLE_ENFORCE_EQ(dim_x.size(), dim_target.size(),
"Input(X) rank and Input(Target) rank should be same.");
for (int i = 0; i < dim_x.size(); i++) {
PADDLE_ENFORCE_EQ(dim_x[i], dim_target[i],
"Input(X) and Input(Target) should in same shape.");
}
auto reduction = ctx->Attrs().Get<std::string>("reduction");
PADDLE_ENFORCE(
"mean" == reduction || "sum" == reduction || "batchmean" == reduction ||
"none" == reduction,
"Attr(reduction) can only be 'none'|'batchmean'|'sum'|'mean'.");
if ("none" == reduction) {
ctx->SetOutputDim("Loss", dim_x);
} else {
ctx->SetOutputDim("Loss", {1});
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
class KLDivLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The input tensor of KL divergence loss operator. "
"This is a tensor with shape of [N, *], where N is the "
"batch size, * means any number of additional dimensions.");
AddInput("Target",
"The tensor of KL divergence loss operator. "
"This is a tensor with shape of Input(X).");
AddOutput(
"Loss",
"The output KL divergence loss tensor. if Attr(reduction) is "
"'none', this tensor should be in same shape of of Input(X), else "
"this tensor should be in shape of [1].");
AddAttr<std::string>(
"reduction",
"The reduction type to apply to the output, available types "
"are 'none' | 'batchmean' | 'mean' | 'sum', 'none' for no "
"reduction, 'batchmean' for the sum of output divided by "
"batch size, 'mean' for the average value of all output, "
"'sum' for the sum of the output.")
.SetDefault("mean");
AddComment(R"DOC(
This operator calculates the Kullback-Leibler divergence loss
between Input(X) and Input(Target).
KL divergence loss is calculated as follows:
$$l(x, y) = y * (\log(y) - x)$$
While :math:`x` is Input(X) and :math:`y` is Input(Target).
While :attr:`reduction` is :attr:`none`, output loss is in
the same shape as Input(X), loss in each point is calculated
seperately and no reduction is applied.
While :attr:`reduction` is :attr:`mean`, output loss is in
shape of [1] and loss value is the mean value of all losses.
While :attr:`reduction` is :attr:`sum`, output loss is in
shape of [1] and loss value is the sum value of all losses.
While :attr:`reduction` is :attr:`batchmean`, output loss is
in shape of [1] and loss value is the sum value of all losses
divided by batch size.
)DOC");
}
};
class KLDivLossOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Target"), "Input(Target) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@GRAD) should not be null");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.GetPlace());
}
};
class KLDivLossOpGradMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
auto* op = new framework::OpDesc();
op->SetType("kldiv_loss_grad");
op->SetInput("X", Input("X"));
op->SetInput("Target", Input("Target"));
op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
op->SetAttrMap(Attrs());
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
return std::unique_ptr<framework::OpDesc>(op);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(kldiv_loss, ops::KLDivLossOp, ops::KLDivLossOpMaker,
ops::KLDivLossOpGradMaker);
REGISTER_OPERATOR(kldiv_loss_grad, ops::KLDivLossOpGrad);
REGISTER_OP_CPU_KERNEL(
kldiv_loss, ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
kldiv_loss_grad,
ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::KLDivLossGradKernel<paddle::platform::CPUDeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/kldiv_loss_op.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
kldiv_loss,
ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, float>,
ops::KLDivLossKernel<paddle::platform::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
kldiv_loss_grad,
ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::KLDivLossGradKernel<paddle::platform::CUDADeviceContext, double>);
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/hostdevice.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
using Array1 = Eigen::DSizes<int64_t, 1>;
template <typename T>
struct KLDivLossForward {
HOSTDEVICE KLDivLossForward() {}
HOSTDEVICE T operator()(const T& target, const T& input) const {
if (target <= 0) {
return 0;
} else {
return target * (std::log(target) - input);
}
}
};
template <typename T>
struct KLDivLossBackward {
HOSTDEVICE KLDivLossBackward() {}
HOSTDEVICE T operator()(const T& target, const T& grad) const {
if (target <= 0) {
return 0;
} else {
return static_cast<T>(-1.) * grad;
}
}
};
template <typename DeviceContext, typename T>
class KLDivLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* input = ctx.Input<Tensor>("X");
auto* target = ctx.Input<Tensor>("Target");
auto* loss = ctx.Output<Tensor>("Loss");
auto reduction = ctx.Attr<std::string>("reduction");
const int n = input->dims()[0];
loss->mutable_data<T>(ctx.GetPlace());
auto input_t = EigenVector<T>::Flatten(*input);
auto target_t = EigenVector<T>::Flatten(*target);
auto loss_t = EigenVector<T>::Flatten(*loss);
auto output = target_t.binaryExpr(input_t, KLDivLossForward<T>());
if ("none" == reduction) {
loss_t.device(place) = output;
} else if ("batchmean" == reduction) {
auto output_sum = output.sum().eval();
loss_t.device(place) = output_sum / output_sum.constant(n);
} else if ("mean" == reduction) {
loss_t.device(place) = output.mean();
} else if ("sum" == reduction) {
loss_t.device(place) = output.sum();
}
}
};
template <typename DeviceContext, typename T>
class KLDivLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& place = *ctx.template device_context<DeviceContext>().eigen_device();
auto* target = ctx.Input<Tensor>("Target");
auto reduction = ctx.Attr<std::string>("reduction");
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* loss_grad = ctx.Input<Tensor>(framework::GradVarName("Loss"));
const int n = input_grad->dims()[0];
const int numel = input_grad->numel();
const int expand = numel / loss_grad->numel();
input_grad->mutable_data<T>(ctx.GetPlace());
auto target_t = EigenVector<T>::Flatten(*target);
auto input_grad_t = EigenVector<T>::Flatten(*input_grad);
auto loss_grad_t = EigenVector<T>::Flatten(*loss_grad);
auto loss_grad_expand = loss_grad_t.broadcast(Array1(expand));
auto grad_t = target_t * loss_grad_expand;
input_grad_t.device(place) =
target_t.binaryExpr(grad_t, KLDivLossBackward<T>());
if ("mean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(numel);
} else if ("batchmean" == reduction) {
input_grad_t.device(place) = input_grad_t / static_cast<T>(n);
}
}
};
} // namespace operators
} // namespace paddle
......@@ -188,6 +188,7 @@ __all__ = [
'psroi_pool',
'teacher_student_sigmoid_loss',
'huber_loss',
'kldiv_loss',
'tree_conv',
'npair_loss',
'fsp_matrix',
......@@ -10762,6 +10763,38 @@ def huber_loss(input, label, delta):
return out
@templatedoc()
def kldiv_loss(x, target, reduction='mean', name=None):
"""
${comment}
Args:
x (Variable): ${x_comment}
target (Variable): ${target_comment}
reduction (Variable): ${reduction_comment}
name (str, default None): The name of this layer.
Returns:
kldiv\_loss (Variable): The KL divergence loss.
Examples:
.. code-block:: python
x = fluid.layers.data(name='x', shape=[4,2,2], dtype='float32')
target = fluid.layers.data(name='target', shape=[4,2,2], dtype='float32')
loss = fluid.layers.kldiv_loss(x=x, target=target, reduction='batchmean')
"""
helper = LayerHelper('kldiv_loss', **locals())
loss = helper.create_variable_for_type_inference(dtype=x.dtype)
helper.append_op(
type='kldiv_loss',
inputs={'X': x,
'Target': target},
outputs={'Loss': loss},
attrs={'reduction': reduction})
return loss
@templatedoc()
def tree_conv(nodes_vector,
edge_set,
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
import unittest
import numpy as np
from op_test import OpTest
def kldiv_loss(x, target, reduction):
output = target * (np.log(target) - x)
loss = np.where(target >= 0, output, np.zeros_like(x))
if reduction == "batchmean":
return loss.sum() / x.shape[0]
if reduction == "mean":
return loss.mean()
if reduction == "sum":
return loss.sum()
return loss
class TestKLDivLossOp(OpTest):
def setUp(self):
self.initTestCase()
self.op_type = 'kldiv_loss'
x = np.random.uniform(-10, 10, self.x_shape).astype('float32')
target = np.random.uniform(-10, 10, self.x_shape).astype('float32')
self.attrs = {"reduction": self.reduction}
self.inputs = {
'X': x,
'Target': target,
}
loss = kldiv_loss(x, target, self.reduction)
self.outputs = {'Loss': loss.astype('float32')}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(
['X'], 'Loss', no_grad_set=set(["Target"]), max_relative_error=0.06)
def initTestCase(self):
self.x_shape = (2, 5, 5)
self.reduction = 'batchmean'
class TestKLDivLossOp2(TestKLDivLossOp):
def initTestCase(self):
self.x_shape = (3, 2, 7, 7)
self.reduction = 'none'
class TestKLDivLossOp3(TestKLDivLossOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 7, 9)
self.reduction = 'mean'
class TestKLDivLossOp4(TestKLDivLossOp):
def initTestCase(self):
self.x_shape = (5, 7)
self.reduction = 'sum'
if __name__ == "__main__":
unittest.main()
......@@ -1591,6 +1591,15 @@ class TestBook(unittest.TestCase):
out = layers.spectral_norm(weight, dim=1, power_iters=1)
self.assertIsNotNone(out)
def test_kldiv_loss(self):
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[32, 128, 128], dtype="float32")
target = layers.data(
name='target', shape=[32, 128, 128], dtype="float32")
loss = layers.kldiv_loss(x=x, target=target, reduction='batchmean')
self.assertIsNotNone(loss)
print(str(program))
def test_temporal_shift(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册