提交 f2e7cf21 编写于 作者: C chengduoZH

fix InferShapeContextBase to InferShapeContext

上级 0f1d3af4
...@@ -28,7 +28,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel { ...@@ -28,7 +28,7 @@ class MaxPoolWithIndexOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null."); "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"), PADDLE_ENFORCE(ctx->HasOutput("Out"),
...@@ -73,7 +73,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel { ...@@ -73,7 +73,7 @@ class MaxPoolWithIndexOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
protected: protected:
void InferShape(framework::InferShapeContextBase *ctx) const override { void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input(X@GRAD) should not be null."); "Input(X@GRAD) should not be null.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册