From 63232e4941ee90ed1470f2488a13ea8b88051160 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Sun, 12 Apr 2020 19:45:40 +0800 Subject: [PATCH] Update OP_INOUT_CHECK (#23757) * update NotFound -> OP_INOUT_CHECK: grid_sampler, kldiv_loss, spectral_norm, temporal_shift. test=develop --- paddle/fluid/operators/grid_sampler_op.cc | 13 ++------ paddle/fluid/operators/kldiv_loss_op.cc | 25 +++++----------- paddle/fluid/operators/spectral_norm_op.cc | 33 +++++++-------------- paddle/fluid/operators/temporal_shift_op.cc | 9 ++---- 4 files changed, 23 insertions(+), 57 deletions(-) diff --git a/paddle/fluid/operators/grid_sampler_op.cc b/paddle/fluid/operators/grid_sampler_op.cc index ea0fc05bbd8..5be49037964 100644 --- a/paddle/fluid/operators/grid_sampler_op.cc +++ b/paddle/fluid/operators/grid_sampler_op.cc @@ -28,16 +28,9 @@ class GridSampleOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of GridSampleOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Grid"), true, - platform::errors::NotFound( - "Input(Grid) of GridSampleOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Output"), true, - platform::errors::NotFound( - "Output(Output) of GridSampleOp should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GridSampler"); + OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "GridSampler"); + OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "GridSampler"); auto x_dims = ctx->GetInputDim("X"); auto grid_dims = ctx->GetInputDim("Grid"); diff --git a/paddle/fluid/operators/kldiv_loss_op.cc b/paddle/fluid/operators/kldiv_loss_op.cc index ad333242396..7286e2e6d31 100644 --- a/paddle/fluid/operators/kldiv_loss_op.cc +++ b/paddle/fluid/operators/kldiv_loss_op.cc @@ -23,15 +23,9 @@ class KLDivLossOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of KLDivLossOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("Target"), true, - platform::errors::NotFound( - "Input(Target) of KLDivLossOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Loss"), true, - platform::errors::NotFound( - "Output(Loss) of KLDivLossOp should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLoss"); + OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLoss"); + OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "KLDivLoss"); auto dim_x = ctx->GetInputDim("X"); auto dim_target = ctx->GetInputDim("Target"); @@ -135,15 +129,10 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("X"), true, - platform::errors::NotFound("Input(X) should not be null")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("Target"), true, - platform::errors::NotFound("Input(Target) should not be null")); - PADDLE_ENFORCE_EQ( - ctx->HasInput(framework::GradVarName("Loss")), true, - platform::errors::NotFound("Input(Loss@GRAD) should not be null")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLossGrad"); + OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLossGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Loss")), "Input", + "Loss@GRAD", "KLDivLossGrad"); auto dim_x = ctx->GetInputDim("X"); if (ctx->HasOutput(framework::GradVarName("X"))) { ctx->SetOutputDim(framework::GradVarName("X"), dim_x); diff --git a/paddle/fluid/operators/spectral_norm_op.cc b/paddle/fluid/operators/spectral_norm_op.cc index a49c25a8fcb..71e5c978d79 100644 --- a/paddle/fluid/operators/spectral_norm_op.cc +++ b/paddle/fluid/operators/spectral_norm_op.cc @@ -26,19 +26,10 @@ class SpectralNormOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("Weight"), true, - platform::errors::NotFound( - "Input(Weight) of SpectralNormOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("U"), true, - platform::errors::NotFound( - "Input(U) of SpectralNormOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasInput("V"), true, - platform::errors::NotFound( - "Input(V) of SpectralNormOp should not be null.")); - PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of SpectralNormOp should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm"); + OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm"); + OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); auto dim_weight = ctx->GetInputDim("Weight"); auto rank_weight = dim_weight.size(); @@ -220,15 +211,13 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ( - ctx->HasInput("Weight"), true, - platform::errors::NotFound("Input(Weight) should not be null")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("U"), true, - platform::errors::NotFound("Input(U) should not be null")); - PADDLE_ENFORCE_EQ( - ctx->HasInput("V"), true, - platform::errors::NotFound("Input(V) should not be null")); + OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", + "SpectralNormGrad"); + OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad"); + OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad"); + OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input", + "Out@GRAD", "SpectralNormGrad"); + PADDLE_ENFORCE_EQ( ctx->HasInput(framework::GradVarName("Out")), true, platform::errors::NotFound("Input(Out@GRAD) should not be null")); diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index 819cac3ee4d..2e87447ed16 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -26,13 +26,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, - platform::errors::NotFound( - "Input(X) of TemporalShiftOp should not be null.")); - PADDLE_ENFORCE_EQ( - ctx->HasOutput("Out"), true, - platform::errors::NotFound( - "Output(Out) of TemporalShiftOp should not be null.")); + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SpectralNorm"); + OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm"); auto dim_x = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, -- GitLab