提交 e6bd53be 编写于 作者: B baojun-nervana

Named to RuntimeInferShape

test=develop
上级 24e70920
...@@ -279,7 +279,7 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ = ...@@ -279,7 +279,7 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ =
ngraph::runtime::Backend::create("CPU"); ngraph::runtime::Backend::create("CPU");
void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) { void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
op->RunInferShape(scope_, place_); op->RuntimeInferShape(scope_, place_);
for (auto& var_name_item : op->Inputs()) { for (auto& var_name_item : op->Inputs()) {
for (auto& var_name : var_name_item.second) { for (auto& var_name : var_name_item.second) {
auto* var = scope_.FindVar(var_name); auto* var = scope_.FindVar(var_name);
......
...@@ -695,8 +695,8 @@ static void CheckTensorNANOrInf(const std::string& name, ...@@ -695,8 +695,8 @@ static void CheckTensorNANOrInf(const std::string& name,
"Tensor %s contains NAN", name); "Tensor %s contains NAN", name);
} }
void OperatorWithKernel::RunInferShape(const Scope& scope, void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
const platform::Place& place) const { const platform::Place& place) const {
RuntimeInferShapeContext infer_shape_ctx(*this, scope); RuntimeInferShapeContext infer_shape_ctx(*this, scope);
this->InferShape(&infer_shape_ctx); this->InferShape(&infer_shape_ctx);
} }
......
...@@ -129,8 +129,8 @@ class OperatorBase { ...@@ -129,8 +129,8 @@ class OperatorBase {
virtual std::vector<std::string> OutputVars(bool has_intermediate) const; virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
virtual void RunInferShape(const Scope& scope, virtual void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const {} const platform::Place& place) const {}
protected: protected:
std::string type_; std::string type_;
...@@ -351,8 +351,8 @@ class OperatorWithKernel : public OperatorBase { ...@@ -351,8 +351,8 @@ class OperatorWithKernel : public OperatorBase {
OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
} }
void RunInferShape(const Scope& scope, void RuntimeInferShape(const Scope& scope,
const platform::Place& place) const override; const platform::Place& place) const override;
protected: protected:
virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册