diff --git a/paddle/operators/auc_op.cc b/paddle/operators/auc_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..d8cecf0957c6cb296b20746422b83d3402c11d11 --- /dev/null +++ b/paddle/operators/auc_op.cc @@ -0,0 +1,85 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/auc_op.h" + +namespace paddle { +namespace operators { + +class AucOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Inference"), + "Input of Inference must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Label"), + "Input of Label must be initialized."); + auto inference_dim = ctx->GetInputDim("Inference"); + auto label_dim = ctx->GetInputDim("Label"); + + PADDLE_ENFORCE_EQ(inference_dim, label_dim, + "inference and label should have same shape"); + + ctx->SetOutputDim("AUC", {1}); + ctx->ShareLoD("Inference", /*->*/ "AUC"); + } +}; + +class AucOpMaker : public framework::OpProtoAndCheckerMaker { + public: + AucOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Inference", + "A floating point tensor of arbitrary shape and whose values" + "are in the range [0, 1]."); + AddInput("Label", + "A tensor whose shape matches " + "Inference. Will be cast to bool."); + // TODO(typhoonzero): support weight input + AddOutput("AUC", + "A scalar representing the " + "current area-under-curve."); + + AddAttr("curve", "Curve type, can be 'ROC' or 'PR'.") + .SetDefault("ROC"); + AddAttr("num_thresholds", + "The number of thresholds to use when discretizing the" + " roc curve.") + .SetDefault(200); + + AddComment( + R"DOC(Computes the AUC according forward output and label. + Best to use for binary classification evaluations. + + If input label contains values other than 0 and 1, it will be cast + to bool. + + You can find the definations here: + https://en.wikipedia.org/wiki/Receiver_operating_characteristic#Area_under_the_curve + + Possible curves are: + - ROC: Receiver operating characteristic + - PR: Precision Recall + )DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_WITHOUT_GRADIENT(auc, ops::AucOp, ops::AucOpMaker); +REGISTER_OP_CPU_KERNEL(auc, ops::AucKernel); diff --git a/paddle/operators/auc_op.h b/paddle/operators/auc_op.h new file mode 100644 index 0000000000000000000000000000000000000000..be6ef29d5f6cff5b9ebdf7d8564b2e2792c3b5cb --- /dev/null +++ b/paddle/operators/auc_op.h @@ -0,0 +1,135 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +using EigenVector = framework::EigenVector; + +template +class AucKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* inference = ctx.Input("Inference"); + auto* label = ctx.Input("Label"); + auto* auc = ctx.Output("AUC"); + + float* auc_data = auc->mutable_data(ctx.GetPlace()); + + std::string curve = ctx.Attr("curve"); + int num_thresholds = ctx.Attr("num_thresholds"); + std::vector thresholds_list; + thresholds_list.reserve(num_thresholds); + for (int i = 1; i < num_thresholds - 1; i++) { + thresholds_list[i] = (float)i / (num_thresholds - 1); + } + const float kEpsilon = 1e-7; + thresholds_list[0] = 0.0f - kEpsilon; + thresholds_list[num_thresholds - 1] = 1.0f + kEpsilon; + + size_t num_samples = inference->numel(); + + const T* inference_data = inference->data(); + Tensor label_casted; + label_casted.Resize(label->dims()); + bool* label_casted_data = label_casted.mutable_data(ctx.GetPlace()); + + const int* label_data = label->data(); + // cast label_data to bool + for (size_t i = 0; i < num_samples; i++) { + label_casted_data[i] = static_cast(label_data[i]); + } + + // Create local tensor for storing the curve: TP, FN, TN, FP + // TODO(typhoonzero): use eigen op to caculate these values. + Tensor true_positive, false_positive, true_negative, false_negative; + + true_positive.Resize({num_thresholds}); + false_negative.Resize({num_thresholds}); + true_negative.Resize({num_thresholds}); + false_positive.Resize({num_thresholds}); + + int* tp_data = true_positive.mutable_data(ctx.GetPlace()); + int* fn_data = false_negative.mutable_data(ctx.GetPlace()); + int* tn_data = true_negative.mutable_data(ctx.GetPlace()); + int* fp_data = false_positive.mutable_data(ctx.GetPlace()); + + for (int idx_thresh = 0; idx_thresh < num_thresholds; idx_thresh++) { + // caculate TP, FN, TN, FP for current thresh + int tp = 0, fn = 0, tn = 0, fp = 0; + for (size_t i = 0; i < num_samples; i++) { + if (label_casted_data[i]) { + if (inference_data[i] >= (thresholds_list[idx_thresh])) { + tp++; + } else { + fn++; + } + } else { + if (inference_data[i] >= (thresholds_list[idx_thresh])) { + fp++; + } else { + tn++; + } + } + } + // store rates + tp_data[idx_thresh] = tp; + fn_data[idx_thresh] = fn; + tn_data[idx_thresh] = tn; + fp_data[idx_thresh] = fp; + } + // epsilon to avoid divide by zero. + float epsilon = 1e-6; + // Riemann sum to caculate auc. + Tensor tp_rate, fp_rate, rec_rate; + tp_rate.Resize({num_thresholds}); + fp_rate.Resize({num_thresholds}); + rec_rate.Resize({num_thresholds}); + float* tp_rate_data = tp_rate.mutable_data(ctx.GetPlace()); + float* fp_rate_data = fp_rate.mutable_data(ctx.GetPlace()); + float* rec_rate_data = rec_rate.mutable_data(ctx.GetPlace()); + for (int i = 0; i < num_thresholds; i++) { + tp_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fn_data[i] + epsilon); + fp_rate_data[i] = (float)fp_data[i] / (fp_data[i] + tn_data[i] + epsilon); + rec_rate_data[i] = + ((float)tp_data[i] + epsilon) / (tp_data[i] + fp_data[i] + epsilon); + } + *auc_data = 0.0f; + if (curve == "ROC") { + for (int i = 0; i < num_thresholds - 1; i++) { + auto dx = fp_rate_data[i] - fp_rate_data[i + 1]; + auto y = (tp_rate_data[i] + tp_rate_data[i + 1]) / 2.0f; + *auc_data = *auc_data + dx * y; + } + } else if (curve == "PR") { + for (int i = 1; i < num_thresholds; i++) { + auto dx = tp_rate_data[i] - tp_rate_data[i - 1]; + auto y = (rec_rate_data[i] + rec_rate_data[i - 1]) / 2.0f; + *auc_data = *auc_data + dx * y; + } + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/huber_loss_op.cc b/paddle/operators/huber_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..2d9449f5ca50dab8d2a7928c4311ec2d66b47904 --- /dev/null +++ b/paddle/operators/huber_loss_op.cc @@ -0,0 +1,122 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/huber_loss_op.h" + +namespace paddle { +namespace operators { + +class HuberLossOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must be initialized."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must be initialized."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + + PADDLE_ENFORCE_EQ(x_dims, y_dims); + PADDLE_ENFORCE_EQ(x_dims.size(), 2, + "The rank of Input(X) must be 2 and the shape is " + "[batch_size, 1]."); + PADDLE_ENFORCE_EQ(x_dims[1], 1, + "Each row of Input(X) contains a real value, " + "so the 2nd dimension of Input(X) must be 1."); + + ctx->SetOutputDim("Residual", x_dims); + ctx->SetOutputDim("Out", {x_dims[0], 1}); + ctx->ShareLoD("X", "Out"); + } +}; + +template +class HuberLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + HuberLossOpMaker(framework::OpProto* proto, + framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "The input value of huber loss op." + "X is a 2-D tensor with shape [batch_size, 1]."); + AddInput("Y", + "The target value of huber loss op." + "Y is a 2-D tensor with shape [batch_size, 1]."); + AddOutput("Residual", + "Intermediate tensor to cache residual value between Y and X." + "The shape is same as Input(X) and will be reused in backward.") + .AsIntermediate(); + AddOutput("Out", + "The output tensor with shape [batch_size, 1] which represents " + "the huber loss."); + AddAttr("delta", "Hyper parameter in huber loss."); + AddComment(R"DOC( +Huber loss is a loss function used in robust regression. We define X as the +input value and Y as the target value. Huber loss can evaluate the fitness of +X to Y. Different from MSE loss, Huber loss is more robust for outliers. The +shape of X and Y are [batch_size, 1]. The equation is: + +L_{\delta}(y, f(x)) = +\begin{cases} +0.5 * (y - f(x))^2, \quad |y - f(x)| \leq \delta \\ +\delta * (|y - f(x)| - 0.5 * \delta), \quad otherwise +\end{cases} + +)DOC"); + } +}; + +class HuberLossGradOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null."); + PADDLE_ENFORCE(ctx->HasInput("Residual"), + "Input(Residual) should not be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) should not be null."); + + auto x_dims = ctx->GetInputDim("X"); + auto y_dims = ctx->GetInputDim("Y"); + auto residual_dims = ctx->GetInputDim("Residual"); + auto out_grad_dims = ctx->GetInputDim(framework::GradVarName("Out")); + + PADDLE_ENFORCE_EQ(residual_dims, x_dims); + PADDLE_ENFORCE_EQ(out_grad_dims, x_dims); + + auto x_grad_name = framework::GradVarName("X"); + auto y_grad_name = framework::GradVarName("Y"); + if (ctx->HasOutput(x_grad_name)) { + ctx->SetOutputDim(x_grad_name, x_dims); + } + if (ctx->HasOutput(y_grad_name)) { + ctx->SetOutputDim(y_grad_name, y_dims); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(huber_loss, ops::HuberLossOp, ops::HuberLossOpMaker, + huber_loss_grad, ops::HuberLossGradOp); +REGISTER_OP_CPU_KERNEL(huber_loss, + ops::HuberLossKernel); +REGISTER_OP_CPU_KERNEL( + huber_loss_grad, + ops::HuberLossGradKernel); diff --git a/paddle/operators/huber_loss_op.cu b/paddle/operators/huber_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..317321dc6c495f6e9a8808d841c71bfa26b754d0 --- /dev/null +++ b/paddle/operators/huber_loss_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/huber_loss_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(huber_loss, + ops::HuberLossKernel); +REGISTER_OP_GPU_KERNEL( + huber_loss_grad, + ops::HuberLossGradKernel); diff --git a/paddle/operators/huber_loss_op.h b/paddle/operators/huber_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..4e7bc5543226e19fe0d6190171cdd9c2b3d2d985 --- /dev/null +++ b/paddle/operators/huber_loss_op.h @@ -0,0 +1,119 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenVector = framework::EigenVector; + +template +struct HuberLossForward { + HOSTDEVICE HuberLossForward(const T& delta) : delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return static_cast(0.5) * val * val; + } else { + return delta * (abs_val - static_cast(0.5) * delta); + } + } + + T delta; +}; + +template +class HuberLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("X"); + auto* in1 = context.Input("Y"); + auto* out0 = context.Output("Residual"); + auto* out1 = context.Output("Out"); + auto delta = static_cast(context.Attr("delta")); + auto place = context.GetEigenDevice(); + + auto x = EigenVector::Flatten(*in0); + auto y = EigenVector::Flatten(*in1); + out0->mutable_data(context.GetPlace()); + auto residual = EigenVector::Flatten(*out0); + residual.device(place) = y - x; + out1->mutable_data(context.GetPlace()); + auto loss = EigenVector::Flatten(*out1); + loss.device(place) = residual.unaryExpr(HuberLossForward(delta)); + } +}; + +template +struct HuberLossBackward { + HOSTDEVICE HuberLossBackward(const T& delta, T sign) + : sign(sign), delta(delta) {} + + HOSTDEVICE T operator()(const T& val) const { + T abs_val = std::abs(val); + if (abs_val <= delta) { + return sign * val; + } else { + if (val > 0) { + return sign * delta; + } else { + return -1 * sign * delta; + } + } + } + + T sign; + T delta; +}; + +template +class HuberLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* in0 = context.Input("Residual"); + auto* in1 = context.Input(framework::GradVarName("Out")); + auto* out0 = context.Output(framework::GradVarName("X")); + auto* out1 = context.Output(framework::GradVarName("Y")); + auto delta = static_cast(context.op().Attr("delta")); + auto place = context.GetEigenDevice(); + + auto residual = EigenVector::Flatten(*in0); + auto out_grad = EigenVector::Flatten(*in1); + + if (out0) { + out0->mutable_data(context.GetPlace()); + auto x_grad = EigenVector::Flatten(*out0); + x_grad.device(place) = + out_grad * residual.unaryExpr(HuberLossBackward(delta, -1.0)); + } + + if (out1) { + out1->mutable_data(context.GetPlace()); + auto y_grad = EigenVector::Flatten(*out1); + y_grad.device(place) = + out_grad * residual.unaryExpr(HuberLossBackward(delta, 1.0)); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_auc_op.py b/python/paddle/v2/framework/tests/test_auc_op.py new file mode 100644 index 0000000000000000000000000000000000000000..f458e01fc567de42a7afaa37d94237f930caf3f1 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_auc_op.py @@ -0,0 +1,66 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestAucOp(OpTest): + def setUp(self): + self.op_type = "auc" + pred = np.random.random((128)).astype("float32") + labels = np.random.randint(0, 2, (128, )) + num_thresholds = 200 + self.inputs = {'Inference': pred, 'Label': labels} + self.attrs = {'curve': 'ROC', 'num_thresholds': num_thresholds} + # NOTE: sklearn use a different way to generate thresholds + # which will cause the result differs slightly: + # from sklearn.metrics import roc_curve, auc + # fpr, tpr, thresholds = roc_curve(labels, pred) + # auc_value = auc(fpr, tpr) + # we caculate AUC again using numpy for testing + kepsilon = 1e-7 # to account for floating point imprecisions + thresholds = [(i + 1) * 1.0 / (num_thresholds - 1) + for i in range(num_thresholds - 2)] + thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] + + # caculate TP, FN, TN, FP count + tp_list = np.ndarray((num_thresholds, )) + fn_list = np.ndarray((num_thresholds, )) + tn_list = np.ndarray((num_thresholds, )) + fp_list = np.ndarray((num_thresholds, )) + for idx_thresh, thresh in enumerate(thresholds): + tp, fn, tn, fp = 0, 0, 0, 0 + for i, lbl in enumerate(labels): + if lbl: + if pred[i] >= thresh: + tp += 1 + else: + fn += 1 + else: + if pred[i] >= thresh: + fp += 1 + else: + tn += 1 + tp_list[idx_thresh] = tp + fn_list[idx_thresh] = fn + tn_list[idx_thresh] = tn + fp_list[idx_thresh] = fp + + epsilon = 1e-6 + tpr = (tp_list.astype("float32") + epsilon) / ( + tp_list + fn_list + epsilon) + fpr = fp_list.astype("float32") / (fp_list + tn_list + epsilon) + rec = (tp_list.astype("float32") + epsilon) / ( + tp_list + fp_list + epsilon) + + x = fpr[:num_thresholds - 1] - fpr[1:] + y = (tpr[:num_thresholds - 1] + tpr[1:]) / 2.0 + auc_value = np.sum(x * y) + + self.outputs = {'AUC': auc_value} + + def test_check_output(self): + self.check_output() + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/v2/framework/tests/test_huber_loss_op.py b/python/paddle/v2/framework/tests/test_huber_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..b2f102d4fc02a524514d4ba1cd523ddc1e4604ea --- /dev/null +++ b/python/paddle/v2/framework/tests/test_huber_loss_op.py @@ -0,0 +1,47 @@ +import unittest +import numpy as np +from op_test import OpTest + + +def huber_loss_forward(val, delta): + abs_val = abs(val) + if abs_val <= delta: + return 0.5 * val * val + else: + return delta * (abs_val - 0.5 * delta) + + +class TestHuberLossOp(OpTest): + def setUp(self): + self.op_type = 'huber_loss' + samples_num = 64 + delta = 1.0 + self.inputs = { + 'X': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), + 'Y': np.random.uniform(0, 1., (samples_num, 1)).astype('float32'), + } + residual = self.inputs['Y'] - self.inputs['X'] + loss = np.vectorize(huber_loss_forward)(residual, delta) + self.attrs = {'delta': delta} + self.outputs = { + 'Residual': residual, + 'Out': loss.reshape((samples_num, 1)) + } + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.008) + + def test_check_grad_ingore_x(self): + self.check_grad( + ['Y'], 'Out', max_relative_error=0.008, no_grad_set=set("residual")) + + def test_check_grad_ingore_y(self): + self.check_grad( + ['X'], 'Out', max_relative_error=0.008, no_grad_set=set('residual')) + + +if __name__ == '__main__': + unittest.main()