diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index af8989dc4f0fdccdd0c5fbcf01ca20ef6f9fdc40..1e9ace99876e00b9f733da3e66f37dc0dc2f8cad 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -309,7 +309,7 @@ template <> std::vector ExecutionContext::MultiOutput( const std::string& name) const; -class CompileTimeInferShapeContext : public InferShapeContextBase { +class CompileTimeInferShapeContext : public InferShapeContext { public: CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) : op_(op), block_(block) {} @@ -405,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase { const BlockDescBind& block_; }; -class RuntimeInferShapeContext : public InferShapeContextBase { +class RuntimeInferShapeContext : public InferShapeContext { public: RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope) : op_(op), scope_(scope) {} @@ -603,7 +603,7 @@ class OperatorWithKernel : public OperatorBase { }); } - virtual void InferShape(InferShapeContextBase* ctx) const = 0; + virtual void InferShape(InferShapeContext* ctx) const = 0; protected: // indicate kernel DataType by input data. Defaultly all input data must be diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index a0c17b41f27d9ec9a0f8e80576a052617919b000..a02f4668bca2360995cc05206f7f97e027db0907 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel { using OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override {} DataType IndicateDataType(const ExecutionContext& ctx) const override { return DataType::FP32; } diff --git a/paddle/framework/shape_inference.h b/paddle/framework/shape_inference.h index 74e0371e328114294d7f85932b1e551c21ff5b97..64aab16ae54d34fd614add348c7c420b4a8f771d 100644 --- a/paddle/framework/shape_inference.h +++ b/paddle/framework/shape_inference.h @@ -20,11 +20,11 @@ namespace paddle { namespace framework { // TODO(longfei): Once after both CompileTimeInferShapeContext and -// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into +// RuntimeInferShapeContext get merged, we can rename InferShapeContext into // InferShapeContext so to replace the current InferShapeContext. -class InferShapeContextBase { +class InferShapeContext { public: - virtual ~InferShapeContextBase() {} + virtual ~InferShapeContext() {} virtual bool HasInput(const std::string &name) const = 0; virtual bool HasOutput(const std::string &name) const = 0; diff --git a/paddle/operators/accuracy_op.cc b/paddle/operators/accuracy_op.cc index 82010bfb53e58a0836c99c353590f4e32e25ac4a..c5fb113e0f4455ec852fd48542ea2917398df5f3 100644 --- a/paddle/operators/accuracy_op.cc +++ b/paddle/operators/accuracy_op.cc @@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Inference"), "Input(Inference) of AccuracyOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Label"), diff --git a/paddle/operators/activation_op.cc b/paddle/operators/activation_op.cc index 66e9d2c40138c26975f07cb544e54de6f00d6b09..5df875cd61d6cc4ac820ccc4f03bd32b3a03936f 100644 --- a/paddle/operators/activation_op.cc +++ b/paddle/operators/activation_op.cc @@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim("Y", ctx->GetInputDim("X")); ctx->ShareLoD("X", /*->*/ "Y"); } @@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y")); } }; diff --git a/paddle/operators/adadelta_op.cc b/paddle/operators/adadelta_op.cc index bd8c93b4a193c169a4472ff20efca779ddc5c804..cf1bca1658f3fb4cd14465791a0f94865871a120 100644 --- a/paddle/operators/adadelta_op.cc +++ b/paddle/operators/adadelta_op.cc @@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of AdadeltaOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/adagrad_op.cc b/paddle/operators/adagrad_op.cc index ea2ff3c50306c0b0db4c769129b6e7f2bab3a7ee..a17747efb79d76a6beffb27bfc03ee1b62c5618d 100644 --- a/paddle/operators/adagrad_op.cc +++ b/paddle/operators/adagrad_op.cc @@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of AdagradOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/clip_op.cc b/paddle/operators/clip_op.cc index b3dd060fd725fc9056b25e4affd82fdb345e77f7..3e9b0d82ba06a918d52124e791715277640fd4ff 100644 --- a/paddle/operators/clip_op.cc +++ b/paddle/operators/clip_op.cc @@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ClipOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/concat_op.cc b/paddle/operators/concat_op.cc index 1ffa02c8f94c01a385d3ba376c1fd0dc3c1bd372..235c4449ace79e87d209615199d6628f269961a9 100644 --- a/paddle/operators/concat_op.cc +++ b/paddle/operators/concat_op.cc @@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL, "Inputs(X) of ConcatOp should be empty.") PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X")); } }; diff --git a/paddle/operators/conv2d_op.cc b/paddle/operators/conv2d_op.cc index 5cc82944bb6b9a4fc5cd94cf2233ab84fc105fe7..6325d4248f10ea8a12ae5398d9fe0e579db3f7ae 100644 --- a/paddle/operators/conv2d_op.cc +++ b/paddle/operators/conv2d_op.cc @@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Input"), "Input(Input) of Conv2DOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Filter"), @@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("Input"); auto filter_dims = ctx->GetInputDim("Filter"); if (ctx->HasOutput(framework::GradVarName("Input"))) { diff --git a/paddle/operators/cos_sim_op.cc b/paddle/operators/cos_sim_op.cc index 040546f1a6fe1af6d17a5e363a11d27de88d03c2..2b4c4b9c45d22e9d0b124145b55cfaeeaf1583d5 100644 --- a/paddle/operators/cos_sim_op.cc +++ b/paddle/operators/cos_sim_op.cc @@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { // notnull check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of CosSimOp should not be null."); @@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { // notnull check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null."); diff --git a/paddle/operators/crop_op.cc b/paddle/operators/crop_op.cc index 9b2305e90e85a6f39d4c584a3251b25f67e81aca..a1424993cc3076f2755164393c5429d79fc1ee78 100644 --- a/paddle/operators/crop_op.cc +++ b/paddle/operators/crop_op.cc @@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of CropOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/cross_entropy_op.cc b/paddle/operators/cross_entropy_op.cc index 4b67887f3638f32a89d1a4fd1316c0596b444629..708e80e96a0f007be1594d8b4781d1296d6b398d 100644 --- a/paddle/operators/cross_entropy_op.cc +++ b/paddle/operators/cross_entropy_op.cc @@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null."); @@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc index a669b5cf00f1f4ad351486e2977bf8a76aa5bf62..708ccfa0bfc163ea75fcd9dbcb5ef6140b4b5f36 100644 --- a/paddle/operators/dropout_op.cc +++ b/paddle/operators/dropout_op.cc @@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE_GE(ctx->Attrs().Get("dropout_prob"), 0); PADDLE_ENFORCE_LE(ctx->Attrs().Get("dropout_prob"), 1); @@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE_EQ(ctx->Attrs().Get("is_training"), 1, "GradOp is only callable when is_training is true"); diff --git a/paddle/operators/elementwise_op.h b/paddle/operators/elementwise_op.h index 3082f37422faa990bbf03c8a1a87b025d481a290..66f1910a4738f737c7581e17729c201a7e3eaeae 100644 --- a/paddle/operators/elementwise_op.h +++ b/paddle/operators/elementwise_op.h @@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel { protected: using Tensor = framework::Tensor; - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of elementwise op should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), @@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel { using Tensor = framework::Tensor; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), diff --git a/paddle/operators/fill_zeros_like_op.cc b/paddle/operators/fill_zeros_like_op.cc index e164de6584e7350283781019cc74118c2d13646e..4c70b9a36bb8f3a2184a8f2ac738378e9d3271e2 100644 --- a/paddle/operators/fill_zeros_like_op.cc +++ b/paddle/operators/fill_zeros_like_op.cc @@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of FillZerosLikeOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), diff --git a/paddle/operators/gather_op.cc b/paddle/operators/gather_op.cc index fe305337cbebd7c679ae1b8ee8aa2740472ee109..fb99c6c01670fcbacd012573b0679ac755638c57 100644 --- a/paddle/operators/gather_op.cc +++ b/paddle/operators/gather_op.cc @@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of GatherOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Index"), @@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } diff --git a/paddle/operators/gaussian_random_op.cc b/paddle/operators/gaussian_random_op.cc index 5cd2c7d2c066cd31e2d38a3c0d682f02339b4d59..ca7fb385057dda4ee6f7dea7edc21d9ebbf6b48e 100644 --- a/paddle/operators/gaussian_random_op.cc +++ b/paddle/operators/gaussian_random_op.cc @@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of GaussianRandomOp should not be null."); auto dims = ctx->Attrs().Get>("dims"); diff --git a/paddle/operators/lookup_table_op.cc b/paddle/operators/lookup_table_op.cc index 929008fbcbe03bd6591b0a02252b343c46d00b8f..3f8d4ab85756ec007004e397a55ea60e2fa704ae 100644 --- a/paddle/operators/lookup_table_op.cc +++ b/paddle/operators/lookup_table_op.cc @@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("W"), "Input(W) of LookupTableOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Ids"), @@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto table_dims = ctx->GetInputDim("W"); ctx->SetOutputDim(framework::GradVarName("W"), table_dims); } diff --git a/paddle/operators/lstm_unit_op.cc b/paddle/operators/lstm_unit_op.cc index dad56731de80518e3bf9d2ec1ffdac9cb6bc92f0..13a45ec246988558b0c090ab14e627b5b0e44532 100644 --- a/paddle/operators/lstm_unit_op.cc +++ b/paddle/operators/lstm_unit_op.cc @@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null."); PADDLE_ENFORCE(ctx->HasInput("C_prev"), "Input(C_prev) of LSTM should not be null."); @@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")), "Input(C@GRAD) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")), diff --git a/paddle/operators/mean_op.cc b/paddle/operators/mean_op.cc index 2332c9546b037c94a5a6d30319abda8e23c2b3bb..441543049faf96c8fbc14e84448a348c3287c1e2 100644 --- a/paddle/operators/mean_op.cc +++ b/paddle/operators/mean_op.cc @@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MeanOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); } }; diff --git a/paddle/operators/minus_op.cc b/paddle/operators/minus_op.cc index 7057dcbd6e375adef57d17a13afdfade67e938b6..d7fd2f901b766c3105a5ea8847fb8a3fe456a945 100644 --- a/paddle/operators/minus_op.cc +++ b/paddle/operators/minus_op.cc @@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MinusOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), diff --git a/paddle/operators/modified_huber_loss_op.cc b/paddle/operators/modified_huber_loss_op.cc index 84212a2b3be1ac3664ebd77c7a0ae4d86abad3a0..6522327fdcf646a60df76909d049341b5a9021c9 100644 --- a/paddle/operators/modified_huber_loss_op.cc +++ b/paddle/operators/modified_huber_loss_op.cc @@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); @@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"), diff --git a/paddle/operators/mul_op.cc b/paddle/operators/mul_op.cc index 3c8fe04d2edeccc0e0d55aa2a28d71085ccf5145..ec0683d8875a945746b676a1f859e614af557441 100644 --- a/paddle/operators/mul_op.cc +++ b/paddle/operators/mul_op.cc @@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), diff --git a/paddle/operators/multiplex_op.cc b/paddle/operators/multiplex_op.cc index a069127a19a1d0ba4eaa2b3450a1c46262ace3ed..a86685b6dde4761cf74f9521bd9609b0864b9bdf 100644 --- a/paddle/operators/multiplex_op.cc +++ b/paddle/operators/multiplex_op.cc @@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null."); PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "MultiInput(X) shouldn't be empty."); @@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null."); PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(), "Output(X@Grad) should not be null."); diff --git a/paddle/operators/pad_op.cc b/paddle/operators/pad_op.cc index 15aa05f26610be14e4c37be35137a259e00eb947..2f26ada85e477a6e5cd24c881f6369548331e8d0 100644 --- a/paddle/operators/pad_op.cc +++ b/paddle/operators/pad_op.cc @@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of PadOp should not be null."); @@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index c29f51f05613832c838400eb114465c81290ea58..ba3b5ed2075ceca284b49ecddb90ba5950b820c3 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X(Input) of Pooling should not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), diff --git a/paddle/operators/prelu_op.cc b/paddle/operators/prelu_op.cc index 1692464f2833a59243ccc1598422180262a59282..166fe2682422c61a264c88b8683752ee323e940e 100644 --- a/paddle/operators/prelu_op.cc +++ b/paddle/operators/prelu_op.cc @@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null"); PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1, @@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/rank_loss_op.cc b/paddle/operators/rank_loss_op.cc index 1ba22006f27abc963e7f161636a964863513a40c..e0abbc4db1723e297c313d50fc1f9cbb77acb9f3 100644 --- a/paddle/operators/rank_loss_op.cc +++ b/paddle/operators/rank_loss_op.cc @@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { // input check PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null"); PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null"); @@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null."); diff --git a/paddle/operators/reduce_op.cc b/paddle/operators/reduce_op.cc index 3ef443d1c7f475cbd578078db02fb5e0d500d060..12081ee6f0d0ff67b3c22ca87556407961fd0486 100644 --- a/paddle/operators/reduce_op.cc +++ b/paddle/operators/reduce_op.cc @@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ReduceOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null."); diff --git a/paddle/operators/reshape_op.cc b/paddle/operators/reshape_op.cc index a3c3fa2716ad9f6487e3eff2d98b2c76d964ddef..3cd54930a0919b7c6aad624b442eb15167b4c354 100644 --- a/paddle/operators/reshape_op.cc +++ b/paddle/operators/reshape_op.cc @@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { // input check PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ReshapeOp should not be null."); @@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) shouldn't be null."); diff --git a/paddle/operators/rmsprop_op.cc b/paddle/operators/rmsprop_op.cc index 8f61c7fdda9f80c69745a9bc4569fcbc099630aa..ada6f2bc3cd5de0424139144d8c0696158c319d9 100644 --- a/paddle/operators/rmsprop_op.cc +++ b/paddle/operators/rmsprop_op.cc @@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of RmspropOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("MeanSquare"), diff --git a/paddle/operators/scale_op.cc b/paddle/operators/scale_op.cc index e225aecc270bc17c535c10253c970b888c42e5d3..ac297da6b72250eaa5eb6ea2177fbcc7a6c1ec56 100644 --- a/paddle/operators/scale_op.cc +++ b/paddle/operators/scale_op.cc @@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel { : OperatorWithKernel(type, inputs, outputs, attrs) {} protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of ScaleOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/operators/scatter_op.cc b/paddle/operators/scatter_op.cc index d15ba151539987c133ac57e102df53551483c6dd..fbea01a8db007a5b8cfe17f545f1cce6ecff3afb 100644 --- a/paddle/operators/scatter_op.cc +++ b/paddle/operators/scatter_op.cc @@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Ref"), "Input(Ref) of ScatterOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Index"), @@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { ctx->SetOutputDim(framework::GradVarName("Updates"), ctx->GetInputDim("Updates")); ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref")); diff --git a/paddle/operators/sequence_pool_op.cc b/paddle/operators/sequence_pool_op.cc index bc4af2f70427e684dfb531b8c61d68f28ae20794..06c00d31eac9ff09084db951ce2929850f9d14c2 100644 --- a/paddle/operators/sequence_pool_op.cc +++ b/paddle/operators/sequence_pool_op.cc @@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequencePoolOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Gradient of Out should not be null."); PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null."); diff --git a/paddle/operators/sequence_softmax_op.cc b/paddle/operators/sequence_softmax_op.cc index 621779ab6133f56a43fb2d20c814ebed8762ea7d..ea217ba4597c24ec323c6509f097874d977fcf81 100644 --- a/paddle/operators/sequence_softmax_op.cc +++ b/paddle/operators/sequence_softmax_op.cc @@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SequenceSoftmaxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), @@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Out"), "Input(Out) of SequenceSoftmaxGradOp should not be null."); PADDLE_ENFORCE( diff --git a/paddle/operators/sgd_op.cc b/paddle/operators/sgd_op.cc index 31d491f130e36289f7dd3dc18710d35d5ba6ab34..2a6a162a02507fa42c5ecbba59384f9c32aba7a9 100644 --- a/paddle/operators/sgd_op.cc +++ b/paddle/operators/sgd_op.cc @@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Param"), "Input(Param) of SGDOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Grad"), diff --git a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc index ede458e01147ab70444c4d158973cfb0818b9bdd..b6653e1cc754bd626f04c28c42ce59f83c2fc87f 100644 --- a/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc +++ b/paddle/operators/sigmoid_cross_entropy_with_logits_op.cc @@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), "Input(Labels) should be not null."); @@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Labels"), "Input(Labels) should be not null."); diff --git a/paddle/operators/smooth_l1_loss_op.cc b/paddle/operators/smooth_l1_loss_op.cc index 2d197e3b1b763fa87939623d47728aab3bff7cd1..91391dc945be3d3d94966339bf9232eb21dcc38f 100644 --- a/paddle/operators/smooth_l1_loss_op.cc +++ b/paddle/operators/smooth_l1_loss_op.cc @@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized."); PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized."); @@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { auto in_dims = ctx->GetInputDim("X"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); diff --git a/paddle/operators/softmax_op.cc b/paddle/operators/softmax_op.cc index e353afee3e10247fbd5c7f4282c366cd1bc39552..4c131ed44de1c1567193c7cabff6ddf2dc63d9be 100644 --- a/paddle/operators/softmax_op.cc +++ b/paddle/operators/softmax_op.cc @@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SoftmaxOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Y"), @@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null."); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")), "Input(Y@GRAD) should be not null."); diff --git a/paddle/operators/softmax_with_cross_entropy_op.cc b/paddle/operators/softmax_with_cross_entropy_op.cc index 42c1ba6fdf1351c43ef78efaaf05c54acb54ce94..5431a1657c300d9c60100616ca7ea2e1196824ec 100644 --- a/paddle/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/operators/softmax_with_cross_entropy_op.cc @@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("Logits"), "Input(Logits) should be not null."); PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null."); @@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")), "Input(Loss@Grad) should not be null."); PADDLE_ENFORCE(ctx->HasInput("Softmax"), diff --git a/paddle/operators/split_op.cc b/paddle/operators/split_op.cc index 5f4b5539affef6fe1d3c4d15fff77d983b5e107f..d5dd4df2a2a6f424421ae7f117c26acff08d37d0 100644 --- a/paddle/operators/split_op.cc +++ b/paddle/operators/split_op.cc @@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SplitOp should not be null."); PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL, diff --git a/paddle/operators/squared_l2_distance_op.cc b/paddle/operators/squared_l2_distance_op.cc index 5a0cb596008a98aacf5e7b5ff70307ea1b8508e6..cce4e527c36bc0f4735165ff88ebba85b00eb898 100644 --- a/paddle/operators/squared_l2_distance_op.cc +++ b/paddle/operators/squared_l2_distance_op.cc @@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of SquaredL2DistanceOp should not be null."); PADDLE_ENFORCE(ctx->HasInput("Y"), @@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Gradient of Out should not be null"); auto out_dims = ctx->GetInputDim(framework::GradVarName("Out")); diff --git a/paddle/operators/sum_op.cc b/paddle/operators/sum_op.cc index c701ee8dde26f0fd50fae227d3f345df2d7a119d..ffb0cb92111bfb8490d35e4f5cfc9e405b0e3250 100644 --- a/paddle/operators/sum_op.cc +++ b/paddle/operators/sum_op.cc @@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null"); auto x_dims = ctx->GetInputsDim("X"); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/operators/top_k_op.cc b/paddle/operators/top_k_op.cc index 5f22bf1df8720b60aba7cd75896d88cd1ad77635..c95481991229dd40e162094ccbf15de8ab627c8b 100644 --- a/paddle/operators/top_k_op.cc +++ b/paddle/operators/top_k_op.cc @@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase *ctx) const override { + void InferShape(framework::InferShapeContext *ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of TopkOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), diff --git a/paddle/operators/transpose_op.cc b/paddle/operators/transpose_op.cc index 0672f9342dac00ecc3f358450a9a203327cbb51f..1101bbe3efe3d313400663ade561063da4e34dfd 100644 --- a/paddle/operators/transpose_op.cc +++ b/paddle/operators/transpose_op.cc @@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null"); auto x_dims = ctx->GetInputDim("X"); @@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null"); PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")), "Input(Out@GRAD) should not be null"); diff --git a/paddle/operators/uniform_random_op.cc b/paddle/operators/uniform_random_op.cc index 97b1d0bed4595cb750e4d2122f294f10edfbe0ff..e330877fc4283b796dcb5c5d745881884ae491ae 100644 --- a/paddle/operators/uniform_random_op.cc +++ b/paddle/operators/uniform_random_op.cc @@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; protected: - void InferShape(framework::InferShapeContextBase* ctx) const override { + void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) of UniformRandomOp should not be null.");