From 4db50fbcddf9ca592c4795b37d2f0d023fbba652 Mon Sep 17 00:00:00 2001 From: Yibing Liu Date: Thu, 28 Sep 2017 17:27:39 +0800 Subject: [PATCH] adapt to the new infershape interface --- paddle/operators/margin_rank_loss_op.cc | 68 ++++++++++--------------- 1 file changed, 26 insertions(+), 42 deletions(-) diff --git a/paddle/operators/margin_rank_loss_op.cc b/paddle/operators/margin_rank_loss_op.cc index 8d62dbb4c..3f94f73fe 100644 --- a/paddle/operators/margin_rank_loss_op.cc +++ b/paddle/operators/margin_rank_loss_op.cc @@ -22,28 +22,21 @@ class MarginRankLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { + void InferShape(framework::InferShapeContextBase *ctx) const override { // input check - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"), - "Output(X2) shouldn't be null."); - auto label_dims = ctx.Input("Label")->dims(); - auto x1_dims = ctx.Input("X1")->dims(); - auto x2_dims = ctx.Input("X2")->dims(); - PADDLE_ENFORCE((label_dims == x1_dims) && (x1_dims == x2_dims) && - (label_dims.size() == 2) && (label_dims[1] == 1), - "All inputs must be vector with the same size."); - auto act_t = ctx.Output("Activated"); - auto out_t = ctx.Output("Out"); - if (act_t) { - act_t->Resize(label_dims); - } - if (out_t) { - out_t->Resize(label_dims); - } + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) shouldn't be null."); + auto label_dims = ctx->GetInputDim("Label"); + auto x1_dims = ctx->GetInputDim("X1"); + auto x2_dims = ctx->GetInputDim("X2"); + PADDLE_ENFORCE( + (label_dims == x1_dims) && (x1_dims == x2_dims) && + (label_dims.size() == 2) && (label_dims[1] == 1), + "All inputs must be 2-D tensor with shape [batch_size x 1]."); + ctx->SetOutputDim("Activated", label_dims); + ctx->SetOutputDim("Out", label_dims); } }; @@ -71,7 +64,7 @@ class MarginRankLossOpMaker : public framework::OpProtoAndCheckerMaker { .AsIntermediate(); AddOutput("Out", "(2-D tensor with shape [batch_size x 1])" - "The output loss of MarginRankLoss operator"); + "The output loss of MarginRankLoss operator."); AddComment(R"DOC( MarginRankLoss operator measures the loss given a pair of input {`X1`, `X2`} @@ -96,26 +89,17 @@ class MarginRankLossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(const framework::InferShapeContext &ctx) const override { - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Label"), - "Input(Label) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X1"), "Input(X1) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X2"), "Input(X2) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), - "Input(Out@GRAD) shouldn't be null."); - PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Activated"), - "Intermediate(Activated) shouldn't be null."); - auto dims = ctx.Input("X1")->dims(); - auto *x1_grad = - ctx.Output(framework::GradVarName("X1")); - auto *x2_grad = - ctx.Output(framework::GradVarName("X2")); - if (x1_grad) { - x1_grad->Resize(dims); - } - if (x2_grad) { - x2_grad->Resize(dims); - } + void InferShape(framework::InferShapeContextBase *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X1"), "Input(X1) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("X2"), "Input(X2) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), + "Input(Out@GRAD) shouldn't be null."); + PADDLE_ENFORCE(ctx->HasInput("Activated"), + "Intermediate(Activated) shouldn't be null."); + auto dims = ctx->GetInputDim("Label"); + ctx->SetOutputDim(framework::GradVarName("X1"), dims); + ctx->SetOutputDim(framework::GradVarName("X2"), dims); } }; -- GitLab