提交 c91de280 编写于 作者: Q Qiao Longfei 提交者: Yu Yang

CompileTime InferShape should find var recursively in stack of blocks (#4998)

* recursive find var in BlockDesc

* add HasVarRecursive and FindVarRecursive to BlockDesc

* fix FindVarRecursive
上级 54ffafa1
......@@ -41,6 +41,19 @@ bool BlockDescBind::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end();
}
VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
auto it = vars_.find(name);
if (it == vars_.end()) {
return Parent() == kNoneBlockIndex ? nullptr
: ParentBlock()->FindVarRecursive(name);
}
return it->second.get();
}
bool BlockDescBind::HasVarRecursive(const std::string &name) const {
return FindVarRecursive(name) != nullptr;
}
std::vector<VarDescBind *> BlockDescBind::AllVars() const {
std::vector<VarDescBind *> res;
for (const auto &p : vars_) {
......@@ -97,7 +110,7 @@ void BlockDescBind::Flush() {
}
BlockDescBind *BlockDescBind::ParentBlock() const {
if (this->desc_->parent_idx() == -1) {
if (this->desc_->parent_idx() == kNoneBlockIndex) {
return nullptr;
}
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <vector>
#include "paddle/framework/op_desc.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/framework/var_desc.h"
#include "paddle/platform/macros.h"
......@@ -56,6 +57,10 @@ class BlockDescBind {
bool HasVar(const std::string &var_name) const;
VarDescBind *FindVarRecursive(const std::string &name_bytes) const;
bool HasVarRecursive(const std::string &var_name) const;
std::set<std::string> LocalVarNames() const {
std::set<std::string> var_names;
for (auto &var : vars_) {
......
......@@ -334,7 +334,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
"Input(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVar(input_names[0]);
return block_.HasVarRecursive(input_names[0]);
}
bool HasOutput(const std::string& name) const override {
......@@ -347,7 +347,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
"Output(%s) should have only one value, "
"but it have %d now",
name, length);
return block_.HasVar(output_names[0]);
return block_.HasVarRecursive(output_names[0]);
}
bool HasInputs(const std::string& name) const override {
......@@ -356,7 +356,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
return false;
}
for (auto& input : input_names) {
if (!block_.HasVar(input)) return false;
if (!block_.HasVarRecursive(input)) return false;
}
return true;
}
......@@ -367,7 +367,7 @@ class CompileTimeInferShapeContext : public InferShapeContext {
return false;
}
for (auto& output : output_names) {
if (!block_.HasVar(output)) return false;
if (!block_.HasVarRecursive(output)) return false;
}
return true;
}
......@@ -414,11 +414,11 @@ class CompileTimeInferShapeContext : public InferShapeContext {
private:
DDim GetDim(const std::string& name) const override {
return framework::make_ddim(block_.FindVar(name)->Shape());
return framework::make_ddim(block_.FindVarRecursive(name)->Shape());
}
void SetDim(const std::string& name, const DDim& dim) override {
block_.FindVar(name)->SetShape(framework::vectorize(dim));
block_.FindVarRecursive(name)->SetShape(framework::vectorize(dim));
}
const OpDescBind& op_;
......
......@@ -35,8 +35,8 @@ ProgramDesc *ProgramDescBind::Proto() {
ProgramDescBind::ProgramDescBind() {
auto *block = prog_.mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block));
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <memory>
#include <vector>
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/proto_desc.h"
#include "paddle/platform/macros.h"
namespace paddle {
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
namespace paddle {
namespace framework {
// The Index of first Block in Program. also called root block.
constexpr int kRootBlockIndex = 0;
// The Parent Index of root Block, this block does not exist.
constexpr int kNoneBlockIndex = -1;
} // namespace framework
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册