diff --git a/paddle/operators/margin_rank_loss_op.cc b/paddle/operators/margin_rank_loss_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..3b9d551b83ba87e0ddcb2f51c9867a7404e88f9f --- /dev/null +++ b/paddle/operators/margin_rank_loss_op.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/margin_rank_loss_op.h" + +namespace paddle { +namespace operators { + +class MarginRankLossOp : public framework::OperatorWithKernel { + public: + MarginRankLossOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + // input check + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null"); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null"); + auto label_dims = ctx.Input("Label")->dims(); + auto x1_dims = ctx.Input("X1")->dims(); + auto x2_dims = ctx.Input("X2")->dims(); + PADDLE_ENFORCE((label_dims.size() == 1) && (x1_dims.size() == 1) && + (x2_dims.size() == 1), + "The rank of all inputs must be 1."); + PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims), + "All inputs must have the same size"); + ctx.Output("Out")->Resize(label_dims); + ctx.Output("Activated")->Resize(label_dims); + } +}; + +template +class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MarginRankLossOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("Label", "The label indicating X1 ranked higher than X2 or not."); + AddInput("X1", "The first input of MarginRankLossOp."); + AddInput("X2", "The second input of MarginRankLossOp"); + AddAttr("margin", "Margin for MarginRankLossOp").SetDefault(0); + AddOutput("Out", "The output loss of MarginRankLoss operator"); + AddOutput("Activated", + "Intermediate tensor to indicate " + "whether Output(Out) is activated") + .AsIntermediate(); + AddComment(R"DOC(MarginRankLoss operator + +loss(x1, x2, y) = max(0, -label * (x1-x2) + margin) + +)DOC"); + } +}; + +class MarginRankLossGradOp : public framework::OperatorWithKernel { + public: + MarginRankLossGradOp(const std::string &type, + const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), + "Input(Label) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Activated"), + "Intermediate(Activated) shouldn't be null."); + auto dims = ctx.Input("X1")->dims(); + auto *x1_grad = + ctx.Output(framework::GradVarName("X1")); + auto *x2_grad = + ctx.Output(framework::GradVarName("X2")); + if (x1_grad) { + x1_grad->Resize(dims); + } + if (x2_grad) { + x2_grad->Resize(dims); + } + } +}; + +} // namespace operators +} // namespace paddle +namespace ops = paddle::operators; + +REGISTER_OP(margin_rank_loss, ops::MarginRankLossOp, + ops::MarginRankLossOpMaker, margin_rank_loss_grad, + ops::MarginRankLossGradOp); +REGISTER_OP_CPU_KERNEL( + margin_rank_loss, + ops::MarginRankLossKernel); +REGISTER_OP_CPU_KERNEL( + margin_rank_loss_grad, + ops::MarginRankLossGradKernel); diff --git a/paddle/operators/margin_rank_loss_op.cu b/paddle/operators/margin_rank_loss_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..81cbf2fe8837f7e46ec13ade233e27178b5be5e4 --- /dev/null +++ b/paddle/operators/margin_rank_loss_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/margin_rank_loss_op.h" + +REGISTER_OP_GPU_KERNEL( + margin_rank_loss, + paddle::operators::MarginRankLossKernel); +REGISTER_OP_GPU_KERNEL(margin_rank_loss_grad, + paddle::operators::MarginRankLossGradKernel< + paddle::platform::GPUPlace, float>); diff --git a/paddle/operators/margin_rank_loss_op.h b/paddle/operators/margin_rank_loss_op.h new file mode 100644 index 0000000000000000000000000000000000000000..cd6544f417b600fc4a588779e76164afa5783aa3 --- /dev/null +++ b/paddle/operators/margin_rank_loss_op.h @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +template +struct ReLU { + HOSTDEVICE T operator()(const T& val) const { + if (val < 0) { + return static_cast(0); + } else { + return val; + } + } +}; + +template +struct Heaviside { + HOSTDEVICE T operator()(const T& val) const { + if (val > 0) { + return static_cast(1); + } else { + return static_cast(0); + } + } +}; + +template +class MarginRankLossKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* out_t = ctx.Output("Out"); + auto* act_t = ctx.Output("Activated"); + + auto* label_t = ctx.Input("Label"); + auto* x1_t = ctx.Input("X1"); + auto* x2_t = ctx.Input("X2"); + + out_t->mutable_data(ctx.GetPlace()); + act_t->mutable_data(ctx.GetPlace()); + + auto margin = static_cast(ctx.Attr("margin")); + auto out = framework::EigenVector::Flatten(*out_t); + auto act = framework::EigenVector::Flatten(*act_t); + + auto label = framework::EigenVector::Flatten(*label_t); + auto x1 = framework::EigenVector::Flatten(*x1_t); + auto x2 = framework::EigenVector::Flatten(*x2_t); + + auto& dev = ctx.GetEigenDevice(); + act.device(dev) = (-label * (x1 - x2) + margin).unaryExpr(Heaviside()); + out.device(dev) = (-label * (x1 - x2) + margin).unaryExpr(ReLU()); + } +}; + +template +class MarginRankLossGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const { + auto* d_x1_t = + ctx.Output(framework::GradVarName("X1")); + auto* d_x2_t = + ctx.Output(framework::GradVarName("X2")); + auto* act_t = ctx.Output("Activated"); + + auto* d_out_t = ctx.Input(framework::GradVarName("Out")); + auto* label_t = ctx.Input("Label"); + + auto& dev = ctx.GetEigenDevice(); + auto d_out = framework::EigenVector::Flatten(*d_out_t); + auto act = framework::EigenVector::Flatten(*act_t); + auto label = framework::EigenVector::Flatten(*label_t); + + // compute d_x1 + if (d_x1_t) { + d_x1_t->mutable_data(ctx.GetPlace()); + auto d_x1 = framework::EigenVector::Flatten(*d_x1_t); + d_x1.device(dev) = -d_out * act * label; + } + // compute d_x2 + if (d_x2_t) { + d_x2_t->mutable_data(ctx.GetPlace()); + auto d_x2 = framework::EigenVector::Flatten(*d_x2_t); + d_x2.device(dev) = d_out * act * label; + } + } +}; +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py b/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py new file mode 100644 index 0000000000000000000000000000000000000000..7118be7cc6f62a596f048de96f88899d81bb7f35 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_margin_rank_loss_op.py @@ -0,0 +1,40 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestMarginRankLossOp(OpTest): + def setUp(self): + self.op_type = "margin_rank_loss" + batch_size = 5 + margin = 0.1 + # labels_{i} = {0, 1.0} or {0, 0.5, 1.0} + label = np.random.randint(0, 2, size=(batch_size, )).astype("float32") + x1 = np.random.random((batch_size, )).astype("float32") + x2 = np.random.random((batch_size, )).astype("float32") + # loss = max(0, -label * (x1 - x2) + margin) + loss = [ + max(0, -label[i] * (x1[i] - x2[i]) + margin) + for i in range(batch_size) + ] + self.attrs = {'margin': margin} + self.inputs = {'Label': label, 'X1': x1, 'X2': x2} + self.outputs = {'Out': loss} + + def test_check_output(self): + self.check_output() + + """ + def test_check_grad(self): + self.check_grad(["X1", "X2"], "Out") + + def test_check_grad_ignore_x1(self): + self.check_grad(["X2"], "Out", no_grad_set=set('X1')) + + def test_check_grad_ignore_x2(self): + self.check_grad(["X1"], "Out", no_grad_set=set('X2')) + """ + + +if __name__ == '__main__': + unittest.main()