提交 c0a34e1c 编写于 作者: Q qiaolongfei

rename InferShapeContextBase to InferShapeContext

上级 a0767228
......@@ -309,7 +309,7 @@ template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
const std::string& name) const;
class CompileTimeInferShapeContext : public InferShapeContextBase {
class CompileTimeInferShapeContext : public InferShapeContext {
public:
CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block)
: op_(op), block_(block) {}
......@@ -405,7 +405,7 @@ class CompileTimeInferShapeContext : public InferShapeContextBase {
const BlockDescBind& block_;
};
class RuntimeInferShapeContext : public InferShapeContextBase {
class RuntimeInferShapeContext : public InferShapeContext {
public:
RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
: op_(op), scope_(scope) {}
......@@ -603,7 +603,7 @@ class OperatorWithKernel : public OperatorBase {
});
}
virtual void InferShape(InferShapeContextBase* ctx) const = 0;
virtual void InferShape(InferShapeContext* ctx) const = 0;
protected:
// indicate kernel DataType by input data. Defaultly all input data must be
......
......@@ -113,7 +113,7 @@ class OpWithKernelTest : public OperatorWithKernel {
using OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {}
void InferShape(framework::InferShapeContext* ctx) const override {}
DataType IndicateDataType(const ExecutionContext& ctx) const override {
return DataType::FP32;
}
......
......@@ -20,11 +20,11 @@ namespace paddle {
namespace framework {
// TODO(longfei): Once after both CompileTimeInferShapeContext and
// RuntimeInferShapeContext get merged, we can rename InferShapeContextBase into
// RuntimeInferShapeContext get merged, we can rename InferShapeContext into
// InferShapeContext so to replace the current InferShapeContext.
class InferShapeContextBase {
class InferShapeContext {
public:
virtual ~InferShapeContextBase() {}
virtual ~InferShapeContext() {}
virtual bool HasInput(const std::string &name) const = 0;
virtual bool HasOutput(const std::string &name) const = 0;
......
......@@ -22,7 +22,7 @@ class AccuracyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Inference"),
"Input(Inference) of AccuracyOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Label"),
......
......@@ -22,7 +22,7 @@ class ActivationOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim("Y", ctx->GetInputDim("X"));
ctx->ShareLoD("X", /*->*/ "Y");
}
......@@ -33,7 +33,7 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("Y"));
}
};
......
......@@ -22,7 +22,7 @@ class AdadeltaOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdadeltaOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......
......@@ -22,7 +22,7 @@ class AdagradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of AdagradOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......
......@@ -22,7 +22,7 @@ class ClipOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ClipOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -61,7 +61,7 @@ class ClipOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
......
......@@ -24,7 +24,7 @@ class ConcatOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_GE(ctx->Inputs("X").size(), 1UL,
"Inputs(X) of ConcatOp should be empty.")
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -83,7 +83,7 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
ctx->SetOutputsDim(framework::GradVarName("X"), ctx->GetInputsDim("X"));
}
};
......
......@@ -27,7 +27,7 @@ class Conv2DOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of Conv2DOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Filter"),
......@@ -106,7 +106,7 @@ class Conv2DOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("Input");
auto filter_dims = ctx->GetInputDim("Filter");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
......
......@@ -24,7 +24,7 @@ class CosSimOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
// notnull check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CosSimOp should not be null.");
......@@ -98,7 +98,7 @@ class CosSimOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
// notnull check
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) must not be null.");
......
......@@ -25,7 +25,7 @@ class CropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of CropOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -115,7 +115,7 @@ class CropOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
......
......@@ -22,7 +22,7 @@ class CrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
......@@ -60,7 +60,7 @@ class CrossEntropyGradientOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
......
......@@ -24,7 +24,7 @@ class DropoutOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE_GE(ctx->Attrs().Get<float>("dropout_prob"), 0);
PADDLE_ENFORCE_LE(ctx->Attrs().Get<float>("dropout_prob"), 1);
......@@ -70,7 +70,7 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), 1,
"GradOp is only callable when is_training is true");
......
......@@ -25,7 +25,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
protected:
using Tensor = framework::Tensor;
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of elementwise op should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"),
......@@ -106,7 +106,7 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
using Tensor = framework::Tensor;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
......
......@@ -22,7 +22,7 @@ class FillZerosLikeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of FillZerosLikeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
......
......@@ -23,7 +23,7 @@ class GatherOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of GatherOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Index"),
......@@ -51,7 +51,7 @@ class GatherGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
......
......@@ -43,7 +43,7 @@ class GaussianRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of GaussianRandomOp should not be null.");
auto dims = ctx->Attrs().Get<std::vector<int>>("dims");
......
......@@ -22,7 +22,7 @@ class LookupTableOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("W"),
"Input(W) of LookupTableOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Ids"),
......@@ -70,7 +70,7 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto table_dims = ctx->GetInputDim("W");
ctx->SetOutputDim(framework::GradVarName("W"), table_dims);
}
......
......@@ -22,7 +22,7 @@ class LstmUnitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of LSTM should not be null.");
PADDLE_ENFORCE(ctx->HasInput("C_prev"),
"Input(C_prev) of LSTM should not be null.");
......@@ -77,7 +77,7 @@ class LstmUnitGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("C")),
"Input(C@GRAD) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("H")),
......
......@@ -22,7 +22,7 @@ class MeanOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MeanOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -47,7 +47,7 @@ class MeanGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
......
......@@ -26,7 +26,7 @@ class MinusOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MinusOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
......
......@@ -22,7 +22,7 @@ class ModifiedHuberLossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
......@@ -74,7 +74,7 @@ class ModifiedHuberLossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("IntermediateVal"),
......
......@@ -24,7 +24,7 @@ class MulOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of MulOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -97,7 +97,7 @@ class MulOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
......
......@@ -24,7 +24,7 @@ class MultiplexOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ids"), "Input(Ids) shouldn't be null.");
PADDLE_ENFORCE(!ctx->Inputs("X").empty(),
"MultiInput(X) shouldn't be empty.");
......@@ -90,7 +90,7 @@ class MultiplexGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(!ctx->Inputs("X").empty(), "Input(X) should not be null.");
PADDLE_ENFORCE(!ctx->Outputs(framework::GradVarName("X")).empty(),
"Output(X@Grad) should not be null.");
......
......@@ -24,7 +24,7 @@ class PadOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of PadOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of PadOp should not be null.");
......@@ -98,7 +98,7 @@ class PadOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
......
......@@ -27,7 +27,7 @@ class PoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -74,7 +74,7 @@ class PoolOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
......
......@@ -26,7 +26,7 @@ class PReluOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
......@@ -63,7 +63,7 @@ class PReluGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
......
......@@ -25,7 +25,7 @@ class RankLossOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null");
PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null");
......@@ -90,7 +90,7 @@ class RankLossGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Left"), "Input(Left) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput("Right"), "Input(Right) shouldn't be null.");
......
......@@ -24,7 +24,7 @@ class ReduceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReduceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -58,7 +58,7 @@ class ReduceGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null.");
......
......@@ -26,7 +26,7 @@ class ReshapeOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReshapeOp should not be null.");
......@@ -94,7 +94,7 @@ class ReshapeGradOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
......
......@@ -22,7 +22,7 @@ class RmspropOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of RmspropOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("MeanSquare"),
......
......@@ -26,7 +26,7 @@ class ScaleOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ScaleOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......
......@@ -23,7 +23,7 @@ class ScatterOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Ref"),
"Input(Ref) of ScatterOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Index"),
......@@ -60,7 +60,7 @@ class ScatterGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->SetOutputDim(framework::GradVarName("Updates"),
ctx->GetInputDim("Updates"));
ctx->SetOutputDim(framework::GradVarName("Ref"), ctx->GetInputDim("Ref"));
......
......@@ -22,7 +22,7 @@ class SequencePoolOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequencePoolOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -74,7 +74,7 @@ class SequencePoolGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"), "The input X should not be null.");
......
......@@ -22,7 +22,7 @@ class SequenceSoftmaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SequenceSoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......@@ -67,7 +67,7 @@ class SequenceSoftmaxGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Out"),
"Input(Out) of SequenceSoftmaxGradOp should not be null.");
PADDLE_ENFORCE(
......
......@@ -22,7 +22,7 @@ class SGDOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Param"),
"Input(Param) of SGDOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Grad"),
......
......@@ -24,7 +24,7 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should be not null.");
......@@ -53,7 +53,7 @@ class SigmoidCrossEntropyWithLogitsGradOp
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Labels"),
"Input(Labels) should be not null.");
......
......@@ -22,7 +22,7 @@ class SmoothL1LossOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "X must be initialized.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Y must be initialized.");
......@@ -94,7 +94,7 @@ class SmoothL1LossGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
auto in_dims = ctx->GetInputDim("X");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
......
......@@ -22,7 +22,7 @@ class SoftmaxOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SoftmaxOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"),
......@@ -69,7 +69,7 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) should be not null.");
......
......@@ -83,7 +83,7 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Logits"),
"Input(Logits) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
......@@ -128,7 +128,7 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Loss")),
"Input(Loss@Grad) should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Softmax"),
......
......@@ -24,7 +24,7 @@ class SplitOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SplitOp should not be null.");
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
......
......@@ -22,7 +22,7 @@ class SquaredL2DistanceOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SquaredL2DistanceOp should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"),
......@@ -86,7 +86,7 @@ class SquaredL2DistanceGradOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Gradient of Out should not be null");
auto out_dims = ctx->GetInputDim(framework::GradVarName("Out"));
......
......@@ -22,7 +22,7 @@ class SumOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInputs("X"), "Inputs(X) should not be null");
auto x_dims = ctx->GetInputsDim("X");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......
......@@ -22,7 +22,7 @@ class TopkOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase *ctx) const override {
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of TopkOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
......
......@@ -24,7 +24,7 @@ class TransposeOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
auto x_dims = ctx->GetInputDim("X");
......@@ -93,7 +93,7 @@ class TransposeOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) should not be null");
......
......@@ -47,7 +47,7 @@ class UniformRandomOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContextBase* ctx) const override {
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UniformRandomOp should not be null.");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册