提交 2d42fa74 编写于 作者: S sweetsky0901

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into my_unpool_max_2d

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/log_loss_op.h"
namespace paddle {
namespace operators {
class LogLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predicted"),
"Input(Predicted) must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) must be initialized.");
auto pred_dims = ctx->GetInputDim("Predicted");
auto label_dims = ctx->GetInputDim("Labels");
PADDLE_ENFORCE_EQ(pred_dims, label_dims);
PADDLE_ENFORCE_EQ(pred_dims.size(), 2,
"The rank of Input(Predicted) must be 2 and the shape is "
"[batch_size, 1].");
PADDLE_ENFORCE_EQ(pred_dims[1], 1,
"Each row of Input(Predicted) contains a real value, "
"so the 2nd dimension of Input(X) must be 1.");
ctx->SetOutputDim("Loss", {pred_dims[0], 1});
ctx->ShareLoD("Predicted", "Loss");
}
};
template <typename AttrType>
class LogLossOpMaker : public framework::OpProtoAndCheckerMaker {
public:
LogLossOpMaker(framework::OpProto* proto,
framework::OpAttrChecker* op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("Predicted",
"The input value (Predicted) of Log loss op."
"Predicted is a 2-D tensor with shape [batch_size, 1].");
AddInput("Labels",
"The target value (Labels) of Log loss op."
"Labels is a 2-D tensor with shape [batch_size, 1].");
AddOutput("Loss",
"The output tensor with shape [batch_size, 1] "
"which represents the log loss.");
AddAttr<AttrType>("epsilon", "Epsilon in log loss.");
AddComment(R"DOC(
LogLoss Operator.
Log loss is a loss function used for binary classification. Log Loss quantifies
the accuracy of a classifier by penalising false classifications. Minimising the
Log Loss is equivalent to maximising the accuracy of the classifier. We define
Predicted as the values predicted by our model and Labels as the target ground
truth value. Log loss can evaluate how close the predicted values are to the
target. The shapes of Predicted and Labels are both [batch_size, 1].
The equation is:
$$
Loss = - Labels * log(Predicted + \epsilon) -
(1 - Labels) * log(1 - Predicted + \epsilon)
$$
)DOC");
}
};
class LogLossGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Predicted"),
"Input(Predicted) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@GRAD) should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("Predicted")),
"Output(Predicted@GRAD) should not be null.");
auto pred_dims = ctx->GetInputDim("Predicted");
auto label_dims = ctx->GetInputDim("Labels");
auto loss_grad_dims = ctx->GetInputDim(framework::GradVarName("Loss"));
PADDLE_ENFORCE_EQ(loss_grad_dims, pred_dims);
auto pred_grad_name = framework::GradVarName("Predicted");
ctx->SetOutputDim(pred_grad_name, pred_dims);
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(log_loss, ops::LogLossOp, ops::LogLossOpMaker<float>, log_loss_grad,
ops::LogLossGradOp);
REGISTER_OP_CPU_KERNEL(log_loss,
ops::LogLossKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
log_loss_grad, ops::LogLossGradKernel<paddle::platform::CPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define EIGEN_USE_GPU
#include "paddle/operators/log_loss_op.h"
namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(log_loss,
ops::LogLossKernel<paddle::platform::GPUPlace, float>);
REGISTER_OP_GPU_KERNEL(
log_loss_grad, ops::LogLossGradKernel<paddle::platform::GPUPlace, float>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenVector = framework::EigenVector<T, MajorType, IndexType>;
template <typename Place, typename T, typename AttrType = T>
class LogLossKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* loss_out = ctx.Output<Tensor>("Loss");
loss_out->mutable_data<T>(ctx.GetPlace());
auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
auto prediction = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Predicted"));
auto label = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Labels"));
auto loss = EigenVector<T>::Flatten(*loss_out);
auto place = ctx.GetEigenDevice<Place>();
loss.device(place) = (-(label * (prediction + epsilon).log()) -
((static_cast<T>(1) - label) *
(static_cast<T>(1) - prediction + epsilon).log()));
}
};
template <typename Place, typename T, typename AttrType = T>
class LogLossGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto epsilon = static_cast<T>(ctx.Attr<AttrType>("epsilon"));
auto prediction = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Predicted"));
auto label = EigenVector<T>::Flatten(*ctx.Input<Tensor>("Labels"));
auto* dloss = ctx.Input<Tensor>(framework::GradVarName("Loss"));
auto* dpred = ctx.Output<Tensor>(framework::GradVarName("Predicted"));
auto dl = EigenVector<T>::Flatten(*dloss);
auto place = ctx.GetEigenDevice<Place>();
if (dpred) {
dpred->mutable_data<T>(ctx.GetPlace());
auto dx = framework::EigenVector<T>::Flatten(*dpred);
dx.device(place) = dl * (-(label / (prediction + epsilon)) +
((static_cast<T>(1) - label) /
(static_cast<T>(1) - prediction + epsilon)));
}
}
};
} // namespace operators
} // namespace paddle
......@@ -2400,6 +2400,14 @@ class CropLayer(LayerBase):
image_conf.img_size_y = input_layer.height
image_conf.channels = input_layer.size / (input_layer.width *
input_layer.height)
# only support for 4-dims inputs and NCHW order
if (len(self.config.inputs) == 2):
self.set_layer_height_width(
self.get_input_layer(1).height, self.get_input_layer(1).width)
self.set_layer_size(self.get_input_layer(1).size)
else:
self.set_layer_height_width(shape[-2], shape[-1])
self.set_layer_size(reduce(lambda x, y: x * y, shape[1:]))
@config_layer('batch_norm')
......@@ -3849,6 +3857,26 @@ class SwitchOrderLayer(LayerBase):
name, 'switch_order', 0, inputs=inputs, **xargs)
self.config.reshape_conf.height_axis.extend(reshape['height'])
self.config.reshape_conf.width_axis.extend(reshape['width'])
input_layer = self.get_input_layer(0)
if reshape is None:
self.set_layer_size(input_layer.size)
else:
in_h = input_layer.height
in_w = input_layer.width
out_dims = None
if input_layer.has_depth():
in_d = input_layer.depth
in_c = input_layer.size / in_h / in_w / in_d
# batch_size, depth, height, width, channel
out_dims = [0, in_d, in_h, in_w, in_c]
else:
in_c = input_layer.size / in_h / in_w
# batch_size, height, width, channel
out_dims = [0, in_h, in_w, in_c]
# Because (reshape['width'][0] > 0) always be true.
# So out_dims[0] won't be used.
size = reduce(lambda x, y: x * y, out_dims[reshape['width'][0]:])
self.set_layer_size(size)
@config_layer('scale_sub_region')
......
......@@ -6873,6 +6873,7 @@ def crop_layer(input, offset, axis=2, shape=None, name=None, layer_attr=None):
:param input: The input of this layer. If two inputs are given, the second one
will be regarded as the reference.
And the input must be 4-dims and in NCHW order.
:type input: LayerOutput | Sequence
:param offset: The crop offset.
:type offset: Sequence
......
import unittest
import numpy as np
from op_test import OpTest
class TestLogLossOp(OpTest):
def setUp(self):
self.op_type = 'log_loss'
samples_num = 32
predicted = np.random.uniform(0.1, 1.0,
(samples_num, 1)).astype("float32")
labels = np.random.randint(0, 2, (samples_num, 1)).astype("float32")
epsilon = 1e-4
self.inputs = {
'Predicted': predicted,
'Labels': labels,
}
self.attrs = {'epsilon': epsilon}
loss = -labels * np.log(predicted + epsilon) - (
1 - labels) * np.log(1 - predicted + epsilon)
self.outputs = {'Loss': loss}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Predicted'], 'Loss', max_relative_error=0.03)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册