diff --git a/paddle/fluid/operators/temporal_shift_op.cc b/paddle/fluid/operators/temporal_shift_op.cc index 3b7d90b795b45d97dfdbe90f7e37ea28b942f2a0..f2a8ae9a411c34ce3f18884d6c2eab45eae5d5ab 100644 --- a/paddle/fluid/operators/temporal_shift_op.cc +++ b/paddle/fluid/operators/temporal_shift_op.cc @@ -26,10 +26,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel { protected: void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of TemporalShiftOp should not be null."); - PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of TemporalShiftOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, + "Input(X) of TemporalShiftOp should not be null."); + PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true, + "Output(Out) of TemporalShiftOp should not be null."); auto dim_x = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(dim_x.size(), 4, @@ -38,9 +38,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel { int seg_num = ctx->Attrs().Get("seg_num"); float shift_ratio = ctx->Attrs().Get("shift_ratio"); PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0."); - PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5, - "Attr(shift_ratio) should be greater than 0 and less " - "than 0.5."); + PADDLE_ENFORCE_GT(shift_ratio, 0., + "Attr(shift_ratio) should be greater than 0"); + PADDLE_ENFORCE_LT(shift_ratio, 0.5, + "Attr(shift_ratio) should be less than 0.5"); if (ctx->IsRuntime()) { PADDLE_ENFORCE_EQ(