diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index 99c4cf0da607cdb5d3282696fc7290187a3f4ed9..61bae1aba41d7630d40585b521688fbd6c069165 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -279,7 +279,7 @@ std::shared_ptr NgraphOperator::backend_ = ngraph::runtime::Backend::create("CPU"); void NgraphOperator::GetNgInputShape(std::shared_ptr op) { - op->RunInferShape(scope_, place_); + op->RuntimeInferShape(scope_, place_); for (auto& var_name_item : op->Inputs()) { for (auto& var_name : var_name_item.second) { auto* var = scope_.FindVar(var_name); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index a816aa94c030024195d4647a49cee8bdda12c568..f3d225df69c5f6c320d60897b959f66e86080205 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -695,8 +695,8 @@ static void CheckTensorNANOrInf(const std::string& name, "Tensor %s contains NAN", name); } -void OperatorWithKernel::RunInferShape(const Scope& scope, - const platform::Place& place) const { +void OperatorWithKernel::RuntimeInferShape(const Scope& scope, + const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); this->InferShape(&infer_shape_ctx); } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index fcf889f3db1da0359258a164fcb90013a3357284..efc9a1b6f5a38f73fec2771d3bc6c3141f642baf 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -129,8 +129,8 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } - virtual void RunInferShape(const Scope& scope, - const platform::Place& place) const {} + virtual void RuntimeInferShape(const Scope& scope, + const platform::Place& place) const {} protected: std::string type_; @@ -351,8 +351,8 @@ class OperatorWithKernel : public OperatorBase { OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); } - void RunInferShape(const Scope& scope, - const platform::Place& place) const override; + void RuntimeInferShape(const Scope& scope, + const platform::Place& place) const override; protected: virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;