未验证 提交 2848cb79 编写于 作者: K Kaipeng Deng 提交者: GitHub

fix temporal_shift OP PADDLE_ENFORCE. test=develop (#19161)

* fix temporal_shift OP PADDLE_ENFORCE. test=develop

* fix HasInput/HasOutpu ENFORECE. test=develop
上级 2f8c7e02
......@@ -26,10 +26,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
protected:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TemporalShiftOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of TemporalShiftOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
"Input(X) of TemporalShiftOp should not be null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
"Output(Out) of TemporalShiftOp should not be null.");
auto dim_x = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(dim_x.size(), 4,
......@@ -38,9 +38,10 @@ class TemporalShiftOp : public framework::OperatorWithKernel {
int seg_num = ctx->Attrs().Get<int>("seg_num");
float shift_ratio = ctx->Attrs().Get<float>("shift_ratio");
PADDLE_ENFORCE_GT(seg_num, 0, "Attr(seg_num) should be greater than 0.");
PADDLE_ENFORCE(shift_ratio > 0 || shift_ratio < .5,
"Attr(shift_ratio) should be greater than 0 and less "
"than 0.5.");
PADDLE_ENFORCE_GT(shift_ratio, 0.,
"Attr(shift_ratio) should be greater than 0");
PADDLE_ENFORCE_LT(shift_ratio, 0.5,
"Attr(shift_ratio) should be less than 0.5");
if (ctx->IsRuntime()) {
PADDLE_ENFORCE_EQ(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册