diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 21d4fdaf0680036a484ee4258e47c6c8854967c3..251e340e6ddcc17ba16bdcab63f2a8c907122eab 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -41,6 +41,19 @@ bool BlockDescBind::HasVar(const std::string &name) const { return vars_.find(name) != vars_.end(); } +VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const { + auto it = vars_.find(name); + if (it == vars_.end()) { + return Parent() == kNoneBlockIndex ? nullptr + : ParentBlock()->FindVarRecursive(name); + } + return it->second.get(); +} + +bool BlockDescBind::HasVarRecursive(const std::string &name) const { + return FindVarRecursive(name) != nullptr; +} + std::vector BlockDescBind::AllVars() const { std::vector res; for (const auto &p : vars_) { @@ -97,7 +110,7 @@ void BlockDescBind::Flush() { } BlockDescBind *BlockDescBind::ParentBlock() const { - if (this->desc_->parent_idx() == -1) { + if (this->desc_->parent_idx() == kNoneBlockIndex) { return nullptr; } return prog_->Block(static_cast(this->desc_->parent_idx())); diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 7d1d33f6860aa90518abb379a5e9964d6029c6fa..c685050850dc25f346df49b5ce1d897974870460 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -21,6 +21,7 @@ limitations under the License. */ #include #include "paddle/framework/op_desc.h" +#include "paddle/framework/proto_desc.h" #include "paddle/framework/var_desc.h" #include "paddle/platform/macros.h" @@ -56,6 +57,10 @@ class BlockDescBind { bool HasVar(const std::string &var_name) const; + VarDescBind *FindVarRecursive(const std::string &name_bytes) const; + + bool HasVarRecursive(const std::string &var_name) const; + std::set LocalVarNames() const { std::set var_names; for (auto &var : vars_) { diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 79a452b616b0109f8d137669cd111b16d5839287..0d0304ac9e13089ef533b0a47f0ec989c8fd7078 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -334,7 +334,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { "Input(%s) should have only one value, " "but it have %d now", name, length); - return block_.HasVar(input_names[0]); + return block_.HasVarRecursive(input_names[0]); } bool HasOutput(const std::string& name) const override { @@ -347,7 +347,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { "Output(%s) should have only one value, " "but it have %d now", name, length); - return block_.HasVar(output_names[0]); + return block_.HasVarRecursive(output_names[0]); } bool HasInputs(const std::string& name) const override { @@ -356,7 +356,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { return false; } for (auto& input : input_names) { - if (!block_.HasVar(input)) return false; + if (!block_.HasVarRecursive(input)) return false; } return true; } @@ -367,7 +367,7 @@ class CompileTimeInferShapeContext : public InferShapeContext { return false; } for (auto& output : output_names) { - if (!block_.HasVar(output)) return false; + if (!block_.HasVarRecursive(output)) return false; } return true; } @@ -414,11 +414,11 @@ class CompileTimeInferShapeContext : public InferShapeContext { private: DDim GetDim(const std::string& name) const override { - return framework::make_ddim(block_.FindVar(name)->Shape()); + return framework::make_ddim(block_.FindVarRecursive(name)->Shape()); } void SetDim(const std::string& name, const DDim& dim) override { - block_.FindVar(name)->SetShape(framework::vectorize(dim)); + block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim)); } const OpDescBind& op_; diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index e2349cefe09a6c1e0b11f77775426fe5c7594c9d..8e99bba81117c9cc50227122527d6ab9a421c251 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -35,8 +35,8 @@ ProgramDesc *ProgramDescBind::Proto() { ProgramDescBind::ProgramDescBind() { auto *block = prog_.mutable_blocks()->Add(); - block->set_idx(0); - block->set_parent_idx(-1); + block->set_idx(kRootBlockIndex); + block->set_parent_idx(kNoneBlockIndex); blocks_.emplace_back(new BlockDescBind(this, block)); } diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index 20cc1a2325ffd6f8ef52879a749f106c268376d4..dc4cd7cc735b5e4e3466d9b82dc5eb8647c80ef9 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -17,6 +17,7 @@ limitations under the License. */ #include #include #include "paddle/framework/framework.pb.h" +#include "paddle/framework/proto_desc.h" #include "paddle/platform/macros.h" namespace paddle { diff --git a/paddle/framework/proto_desc.h b/paddle/framework/proto_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..fa01224fefce50eb3688ff407f0a7c948c5b7cfc --- /dev/null +++ b/paddle/framework/proto_desc.h @@ -0,0 +1,26 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +namespace paddle { +namespace framework { + +// The Index of first Block in Program. also called root block. +constexpr int kRootBlockIndex = 0; +// The Parent Index of root Block, this block does not exist. +constexpr int kNoneBlockIndex = -1; + +} // namespace framework +} // namespace paddle