未验证 提交 63232e49 编写于 作者: K Kaipeng Deng 提交者: GitHub

Update OP_INOUT_CHECK (#23757)

* update NotFound -> OP_INOUT_CHECK: grid_sampler, kldiv_loss, spectral_norm, temporal_shift. test=develop
上级 9e85d023
......@@ -28,16 +28,9 @@ class GridSampleOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of GridSampleOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Grid"), true,
platform::errors::NotFound(
"Input(Grid) of GridSampleOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Output"), true,
platform::errors::NotFound(
"Output(Output) of GridSampleOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GridSampler");
OP_INOUT_CHECK(ctx->HasInput("Grid"), "Input", "Grid", "GridSampler");
OP_INOUT_CHECK(ctx->HasOutput("Output"), "Output", "Output", "GridSampler");
auto x_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
......
......@@ -23,15 +23,9 @@ class KLDivLossOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of KLDivLossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("Target"), true,
platform::errors::NotFound(
"Input(Target) of KLDivLossOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Loss"), true,
platform::errors::NotFound(
"Output(Loss) of KLDivLossOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLoss");
OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLoss");
OP_INOUT_CHECK(ctx->HasOutput("Loss"), "Output", "Loss", "KLDivLoss");
auto dim_x = ctx->GetInputDim("X");
auto dim_target = ctx->GetInputDim("Target");
......@@ -135,15 +129,10 @@ class KLDivLossOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
platform::errors::NotFound("Input(X) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("Target"), true,
platform::errors::NotFound("Input(Target) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Loss")), true,
platform::errors::NotFound("Input(Loss@GRAD) should not be null"));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "KLDivLossGrad");
OP_INOUT_CHECK(ctx->HasInput("Target"), "Input", "Target", "KLDivLossGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Loss")), "Input",
"Loss@GRAD", "KLDivLossGrad");
auto dim_x = ctx->GetInputDim("X");
if (ctx->HasOutput(framework::GradVarName("X"))) {
ctx->SetOutputDim(framework::GradVarName("X"), dim_x);
......
......@@ -26,19 +26,10 @@ class SpectralNormOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Weight"), true,
platform::errors::NotFound(
"Input(Weight) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("U"), true,
platform::errors::NotFound(
"Input(U) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasInput("V"), true,
platform::errors::NotFound(
"Input(V) of SpectralNormOp should not be null."));
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of SpectralNormOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
auto dim_weight = ctx->GetInputDim("Weight");
auto rank_weight = dim_weight.size();
......@@ -220,15 +211,13 @@ class SpectralNormOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("Weight"), true,
platform::errors::NotFound("Input(Weight) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("U"), true,
platform::errors::NotFound("Input(U) should not be null"));
PADDLE_ENFORCE_EQ(
ctx->HasInput("V"), true,
platform::errors::NotFound("Input(V) should not be null"));
OP_INOUT_CHECK(ctx->HasInput("Weight"), "Input", "Weight",
"SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput("U"), "Input", "U", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput("V"), "Input", "V", "SpectralNormGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
"Out@GRAD", "SpectralNormGrad");
PADDLE_ENFORCE_EQ(
ctx->HasInput(framework::GradVarName("Out")), true,
platform::errors::NotFound("Input(Out@GRAD) should not be null"));
......
......@@ -26,13 +26,8 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::NotFound(
"Input(X) of TemporalShiftOp should not be null."));
PADDLE_ENFORCE_EQ(
ctx->HasOutput("Out"), true,
platform::errors::NotFound(
"Output(Out) of TemporalShiftOp should not be null."));
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SpectralNorm");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SpectralNorm");
auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册