Commit caf4b937 authored by baojun-nervana

Added RunInferShape

test=develop
Parent 1d19eb2b
...
@@ -278,39 +278,22 @@ std::shared_ptr<ngraph::runtime::Backend> NgraphOperator::backend_ =
     ngraph::runtime::Backend::create("CPU");
 
 void NgraphOperator::GetNgInputShape(std::shared_ptr<OperatorBase> op) {
-  RuntimeInferShapeContext infer_shape_ctx(*op, scope_);
-  std::shared_ptr<OperatorWithKernel> op_k =
-      std::dynamic_pointer_cast<OperatorWithKernel>(op);
-  op_k->InferShape(&infer_shape_ctx);
+  op->RunInferShape(scope_, place_);
   for (auto& var_name_item : op->Inputs()) {
-    std::vector<ngraph::Shape> vshape;
-    auto& var_prm_name = var_name_item.first;
-    auto var_name_size = var_name_item.second.size();
-    if (var_name_size == 1) {
-      auto dim = infer_shape_ctx.GetInputDim(var_prm_name);
-      vshape.push_back(Ddim2Shape(dim));
-    } else if (var_name_item.second.size() > 1) {
-      auto vdim = infer_shape_ctx.GetInputsDim(var_prm_name);
-      PADDLE_ENFORCE_EQ(vdim.size(), var_name_item.second.size(),
-                        "Need dim info for each var");
-      for (auto& dim : vdim) {
-        vshape.push_back(Ddim2Shape(dim));
-      }
-    } else {
-      // 0 size : conv2d Bias
-    }
-    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
-      auto var_name = var_name_item.second.at(i);
-      if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
-          var_in_.end()) {
-        if (var_node_map_->find(var_name) == var_node_map_->end()) {
-          auto ng_type = var_type_map_.at(var_name);
-          auto prm = std::make_shared<ngraph::op::Parameter>(
-              ng_type, vshape.at(i), true);
-          (*var_node_map_)[var_name] = prm;
-          (*var_in_node_map_)[var_name] = prm;
-        }
+    for (auto& var_name : var_name_item.second) {
+      auto* var = scope_.FindVar(var_name);
+      if (var && VarIsTensor(*var)) {
+        auto* tensor_pd = GetLoDTensorOrSelectedRowsValueFromVar(*var);
+        auto sp = Ddim2Shape(tensor_pd->dims());
+        if (std::find(var_in_.begin(), var_in_.end(), var_name) !=
+            var_in_.end()) {
+          if (var_node_map_->find(var_name) == var_node_map_->end()) {
+            auto ng_type = var_type_map_.at(var_name);
+            auto prm =
+                std::make_shared<ngraph::op::Parameter>(ng_type, sp, true);
+            (*var_node_map_)[var_name] = prm;
+            (*var_in_node_map_)[var_name] = prm;
+          }
+        }
       }
     }
   }
 }
...
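Note on the change above: the old code had to `dynamic_pointer_cast` the `OperatorBase` pointer to `OperatorWithKernel` before it could call `InferShape`; the new code goes through the virtual `RunInferShape` hook added later in this commit, then reads the already-inferred dims straight off the scope's tensors, which removes the 0/1/N input-slot branching on `GetInputDim`/`GetInputsDim`. A minimal standalone sketch of the dispatch pattern, with stand-in types rather than the actual Paddle classes:

```cpp
#include <iostream>
#include <memory>

// Stand-ins for the framework's Scope and platform::Place.
struct Scope {};
struct Place {};

struct OperatorBase {
  virtual ~OperatorBase() = default;
  // Default is a no-op: operators without kernels have nothing to infer.
  virtual void RunInferShape(const Scope&, const Place&) const {}
};

struct OperatorWithKernel : OperatorBase {
  // Kernel operators override the hook with their real infer-shape logic.
  void RunInferShape(const Scope&, const Place&) const override {
    std::cout << "shapes inferred\n";
  }
};

int main() {
  std::shared_ptr<OperatorBase> op = std::make_shared<OperatorWithKernel>();
  op->RunInferShape(Scope{}, Place{});  // virtual dispatch, no downcast needed
}
```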
...
@@ -355,7 +355,7 @@ void OperatorBase::GenerateTemporaryNames() {
   }
 }
 
-static bool VarIsTensor(const Variable& var) {
+bool VarIsTensor(const Variable& var) {
   return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
 }
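Dropping `static` here is what lets the declaration `bool VarIsTensor(const Variable& var);` be added to operator.h below: at namespace scope, `static` gives a function internal linkage, so code in another translation unit (such as the nGraph bridge above) could not have called it. A two-line reminder of the rule, with hypothetical helper names:

```cpp
static bool HelperA(int);  // internal linkage: visible only in this .cc file
bool HelperB(int);         // external linkage: declarable in a header and
                           // callable from other translation units
```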
...
@@ -695,6 +695,12 @@ static void CheckTensorNANOrInf(const std::string& name,
                          "Tensor %s contains NAN", name);
 }
 
+void OperatorWithKernel::RunInferShape(const Scope& scope,
+                                       const platform::Place& place) const {
+  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
+  this->InferShape(&infer_shape_ctx);
+}
+
 void OperatorWithKernel::RunImpl(const Scope& scope,
                                  const platform::Place& place) const {
   RuntimeInferShapeContext infer_shape_ctx(*this, scope);
...
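The new method factors the first step of `RunImpl` (build a `RuntimeInferShapeContext`, run `InferShape`) into its own entry point, so callers such as `GetNgInputShape` can resolve shapes without selecting or executing a kernel. A condensed sketch of that factoring, using placeholder types rather than the real framework classes:

```cpp
struct InferShapeContext {};  // placeholder for RuntimeInferShapeContext

struct Op {
  void InferShape(InferShapeContext* ctx) const { /* op-specific dim rules */ }

  // New entry point: shape inference only.
  void RunInferShape() const {
    InferShapeContext ctx;
    InferShape(&ctx);
  }

  // Existing path: the same inference step, then kernel selection and launch.
  void RunImpl() const {
    InferShapeContext ctx;
    InferShape(&ctx);
    // ... pick a kernel for the current place and run it ...
  }
};
```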
...
@@ -64,6 +64,7 @@ inline std::string GradVarName(const std::string& var_name) {
 }
 
 proto::VarType::Type GetDataTypeOfVar(const Variable* var);
+bool VarIsTensor(const Variable& var);
 const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
 Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);
...
@@ -128,6 +129,8 @@ class OperatorBase {
   virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
 
   void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }
+  virtual void RunInferShape(const Scope& scope,
+                             const platform::Place& place) const {}
 
  protected:
   std::string type_;
...
@@ -348,6 +351,9 @@ class OperatorWithKernel : public OperatorBase {
     OpInfoMap::Instance().Get(Type()).infer_shape_(ctx);
   }
 
+  void RunInferShape(const Scope& scope,
+                     const platform::Place& place) const override;
+
  protected:
   virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;
   virtual OpKernelType GetKernelTypeForVar(
...