From c91de280d783d531792e8a458cc50342eb405f59 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sun, 22 Oct 2017 10:54:42 -0700 Subject: [PATCH] CompileTime InferShape should find var recursively in stack of blocks (#4998) * recursive find var in BlockDesc * add HasVarRecursive and FindVarRecursive to BlockDesc * fix FindVarRecursive --- paddle/framework/block_desc.cc | 15 ++++++++++++++- paddle/framework/block_desc.h | 5 +++++ paddle/framework/operator.h | 12 ++++++------ paddle/framework/program_desc.cc | 4 ++-- paddle/framework/program_desc.h | 1 + paddle/framework/proto_desc.h | 26 ++++++++++++++++++++++++++ 6 files changed, 54 insertions(+), 9 deletions(-) create mode 100644 paddle/framework/proto_desc.h diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 21d4fdaf0..251e340e6 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -41,6 +41,19 @@ bool BlockDescBind::HasVar(const std::string &name) const { return vars_.find(name) != vars_.end(); } +VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const { + auto it = vars_.find(name); + if (it == vars_.end()) { + return Parent() == kNoneBlockIndex ? nullptr + : ParentBlock()->FindVarRecursive(name); + } + return it->second.get(); +} + +bool BlockDescBind::HasVarRecursive(const std::string &name) const { + return FindVarRecursive(name) != nullptr; +} + std::vector BlockDescBind::AllVars() const { std::vector res; for (const auto &p : vars_) { @@ -97,7 +110,7 @@ void BlockDescBind::Flush() { } BlockDescBind *BlockDescBind::ParentBlock() const { - if (this->desc_->parent_idx() == -1) { + if (this->desc_->parent_idx() == kNoneBlockIndex) { return nullptr; } return prog_->Block(static_cast(this->desc_->parent_idx())); diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 7d1d33f68..c68505085 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include "paddle/framework/op_desc.h" +#include "paddle/framework/proto_desc.h" #include "paddle/framework/var_desc.h" #include "paddle/platform/macros.h" @@ -56,6 +57,10 @@ class BlockDescBind { bool HasVar(const std::string &var_name) const; + VarDescBind *FindVarRecursive(const std::string &name_bytes) const; + + bool HasVarRecursive(const std::string &var_name) const; + std::set LocalVarNames() const { std::set var_names; for (auto &var : vars_) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 79a452b61..0d0304ac9 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -334,7 +334,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { "Input(%s) should have only one value, " "but it have %d now", name, length); - return block_.HasVar(input_names[0]); + return block_.HasVarRecursive(input_names[0]); } bool HasOutput(const std::string& name) const override { @@ -347,7 +347,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { "Output(%s) should have only one value, " "but it have %d now", name, length); - return block_.HasVar(output_names[0]); + return block_.HasVarRecursive(output_names[0]); } bool HasInputs(const std::string& name) const override { @@ -356,7 +356,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { return false; } for (auto& input : input_names) { - if (!block_.HasVar(input)) return false; + if (!block_.HasVarRecursive(input)) return false; } return true; } @@ -367,7 +367,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { return false; } for (auto& output : output_names) { - if (!block_.HasVar(output)) return false; + if (!block_.HasVarRecursive(output)) return false; } return true; } @@ -414,11 +414,11 @@ class CompileTimeInferShapeContext : public InferShapeContext { private: DDim GetDim(const std::string& name) const override { - return framework::make_ddim(block_.FindVar(name)->Shape()); + return framework::make_ddim(block_.FindVarRecursive(name)->Shape()); } void SetDim(const std::string& name, const DDim& dim) override { - block_.FindVar(name)->SetShape(framework::vectorize(dim)); + block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); } const OpDescBind& op_; diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index e2349cefe..8e99bba81 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -35,8 +35,8 @@ ProgramDesc *ProgramDescBind::Proto() { ProgramDescBind::ProgramDescBind() { auto *block = prog_.mutable_blocks()->Add(); - block->set_idx(0); - block->set_parent_idx(-1); + block->set_idx(kRootBlockIndex); + block->set_parent_idx(kNoneBlockIndex); blocks_.emplace_back(new BlockDescBind(this, block)); } diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 20cc1a232..dc4cd7cc7 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/framework/framework.pb.h" +#include "paddle/framework/proto_desc.h" #include "paddle/platform/macros.h" namespace paddle { diff --git a/paddle/framework/proto_desc.h b/paddle/framework/proto_desc.h new file mode 100644 index 000000000..fa01224fe --- /dev/null +++ b/paddle/framework/proto_desc.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace framework { + +// The Index of first Block in Program. also called root block. +constexpr int kRootBlockIndex = 0; +// The Parent Index of root Block, this block does not exist. +constexpr int kNoneBlockIndex = -1; + +} // namespace framework +} // namespace paddle -- GitLab