diff --git a/paddle/fluid/framework/ngraph_operator.cc b/paddle/fluid/framework/ngraph_operator.cc index e9ff0513557725d068c3d3082f15c1f773e6624a..8878917da1683606b3e5602e4981c10b88a7d735 100644 --- a/paddle/fluid/framework/ngraph_operator.cc +++ b/paddle/fluid/framework/ngraph_operator.cc @@ -278,39 +278,22 @@ std::shared_ptr NgraphOperator::backend_ = ngraph::runtime::Backend::create("CPU"); void NgraphOperator::GetNgInputShape(std::shared_ptr op) { - RuntimeInferShapeContext infer_shape_ctx(*op, scope_); - std::shared_ptr op_k = - std::dynamic_pointer_cast(op); - op_k->InferShape(&infer_shape_ctx); - + op->RunInferShape(scope_, place_); for (auto& var_name_item : op->Inputs()) { - std::vector vshape; - auto& var_prm_name = var_name_item.first; - auto var_name_size = var_name_item.second.size(); - if (var_name_size == 1) { - auto dim = infer_shape_ctx.GetInputDim(var_prm_name); - vshape.push_back(Ddim2Shape(dim)); - } else if (var_name_item.second.size() > 1) { - auto vdim = infer_shape_ctx.GetInputsDim(var_prm_name); - PADDLE_ENFORCE_EQ(vdim.size(), var_name_item.second.size(), - "Need dim info for each var"); - for (auto& dim : vdim) { - vshape.push_back(Ddim2Shape(dim)); - } - } else { - // 0 size : conv2d Bias - } - - for (size_t i = 0; i < var_name_item.second.size(); ++i) { - auto var_name = var_name_item.second.at(i); - if (std::find(var_in_.begin(), var_in_.end(), var_name) != - var_in_.end()) { - if (var_node_map_->find(var_name) == var_node_map_->end()) { - auto ng_type = var_type_map_.at(var_name); - auto prm = std::make_shared( - ng_type, vshape.at(i), true); - (*var_node_map_)[var_name] = prm; - (*var_in_node_map_)[var_name] = prm; + for (auto& var_name : var_name_item.second) { + auto* var = scope_.FindVar(var_name); + if (var && VarIsTensor(*var)) { + auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var); + auto sp = Ddim2Shape(tensor_pd->dims()); + if (std::find(var_in_.begin(), var_in_.end(), var_name) != + var_in_.end()) { + if (var_node_map_->find(var_name) == var_node_map_->end()) { + auto ng_type = var_type_map_.at(var_name); + auto prm = + std::make_shared(ng_type, sp, true); + (*var_node_map_)[var_name] = prm; + (*var_in_node_map_)[var_name] = prm; + } } } } diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 8bfdf3891203823826fd5bf919c176011f22213c..a816aa94c030024195d4647a49cee8bdda12c568 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -355,7 +355,7 @@ void OperatorBase::GenerateTemporaryNames() { } } -static bool VarIsTensor(const Variable& var) { +bool VarIsTensor(const Variable& var) { return var.IsType() || var.IsType(); } @@ -695,6 +695,12 @@ static void CheckTensorNANOrInf(const std::string& name, "Tensor %s contains NAN", name); } +void OperatorWithKernel::RunInferShape(const Scope& scope, + const platform::Place& place) const { + RuntimeInferShapeContext infer_shape_ctx(*this, scope); + this->InferShape(&infer_shape_ctx); +} + void OperatorWithKernel::RunImpl(const Scope& scope, const platform::Place& place) const { RuntimeInferShapeContext infer_shape_ctx(*this, scope); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 5bd68f9ac2e1b30bc6ce3094960bb89842b99e01..fcf889f3db1da0359258a164fcb90013a3357284 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -64,6 +64,7 @@ inline std::string GradVarName(const std::string& var_name) { } proto::VarType::Type GetDataTypeOfVar(const Variable* var); +bool VarIsTensor(const Variable& var); const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var); Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var); @@ -128,6 +129,8 @@ class OperatorBase { virtual std::vector OutputVars(bool has_intermediate) const; void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; } + virtual void RunInferShape(const Scope& scope, + const platform::Place& place) const {} protected: std::string type_; @@ -348,6 +351,9 @@ class OperatorWithKernel : public OperatorBase { OpInfoMap::Instance().Get(Type()).infer_shape_(ctx); } + void RunInferShape(const Scope& scope, + const platform::Place& place) const override; + protected: virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const; virtual OpKernelType GetKernelTypeForVar(