From 2848cb791e17e7254c97e8280383279f1dd96a33 Mon Sep 17 00:00:00 2001 From: Kaipeng Deng Date: Fri, 16 Aug 2019 10:20:28 +0800 Subject: [PATCH] fix temporal_shift OP PADDLE_ENFORCE. test=develop (#19161) * fix temporal_shift OP PADDLE_ENFORCE. test=develop * fix HasInput/HasOutpu ENFORECE. test=develop --- paddle/fluid/operators/temporal_shift_op.cc | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index 3b7d90b795b..f2a8ae9a411 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -26,10 +26,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of TemporalShiftOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of TemporalShiftOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of TemporalShiftOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of TemporalShiftOp should not be null."); auto dim_x = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, @@ -38,9 +38,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel { int seg_num = ctx->Attrs().Get("seg_num"); float shift_ratio = ctx->Attrs().Get("shift_ratio"); PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0."); - PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5, - "Attr(shift_ratio) should be greater than 0 and less " - "than 0.5."); + PADDLE_ENFORCE_GT(shift_ratio, 0., + "Attr(shift_ratio) should be greater than 0"); + PADDLE_ENFORCE_LT(shift_ratio, 0.5, + "Attr(shift_ratio) should be less than 0.5"); if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ( -- GitLab