From d550380ed92ed9c762ff0248780a3c28dbf27416 Mon Sep 17 00:00:00 2001 From: qiaolongfei Date: Mon, 2 Oct 2017 16:52:11 -0700 Subject: [PATCH] add CompileTimeInferShapeContext --- paddle/framework/CMakeLists.txt | 2 +- paddle/framework/block_desc.cc | 5 ++ paddle/framework/block_desc.h | 2 + paddle/framework/operator.h | 84 +++++++++++++-------------------- 4 files changed, 41 insertions(+), 52 deletions(-) diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 5d394132b7f..a2efcdb55cf 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -23,7 +23,7 @@ cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto proto_desc) -cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope) +cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope proto_desc) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry) cc_library(grad_op_builder SRCS grad_op_builder.cc DEPS operator proto_desc) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc index 9570aedfdda..670533a3fe3 100644 --- a/paddle/framework/block_desc.cc +++ b/paddle/framework/block_desc.cc @@ -34,6 +34,11 @@ VarDescBind *BlockDescBind::Var(const std::string &name) const { return it->second.get(); } +bool BlockDescBind::HasVar(const std::string &name) const { + auto it = vars_.find(name); + return it != vars_.end(); +} + std::vector BlockDescBind::AllVars() const { std::vector res; for (const auto &p : vars_) { diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h index 1a1135bab44..41cf1dc385a 100644 --- a/paddle/framework/block_desc.h +++ b/paddle/framework/block_desc.h @@ -45,6 +45,8 @@ class BlockDescBind { VarDescBind *Var(const std::string &name_bytes) const; + bool HasVar(const std::string &var_name) const; + std::vector AllVars() const; BlockDescBind *ParentBlock() const; diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index f807909650f..2874237b9c3 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -319,100 +319,82 @@ class ExecutionContext : public InferShapeContext { class CompileTimeInferShapeContext : public InferShapeContextBase { public: - CompileTimeInferShapeContext(const OperatorBase& op, const Scope& scope) - : op_(op), scope_(scope) {} + CompileTimeInferShapeContext(const OpDescBind& op, const BlockDescBind& block) + : op_(op), block_(block) {} bool HasInput(const std::string& name) const { - auto ipt = op_.Input(name); - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + const std::vector& input_names = op_.Input(name); + PADDLE_ENFORCE_EQ(input_names.size(), 1UL, "Inputs(%s) length is not 1", + name); + return block_.HasVar(input_names[0]); } bool HasOutput(const std::string& name) const { - auto ipt = op_.Output(name); - auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt); - return var != nullptr; + const std::vector& output_names = op_.Output(name); + PADDLE_ENFORCE_EQ(output_names.size(), 1UL, "Outputs(%s) length is not 1", + name); + return block_.HasVar(output_names[0]); } bool HasInputs(const std::string& name) const { - auto inputs = op_.Inputs(name); - if (inputs.size() == 0UL) { - return false; - } - for (auto& input : inputs) { - if (scope_.FindVar(input) == nullptr) { - return false; - } + const std::vector& input_names = op_.Input(name); + PADDLE_ENFORCE_GT(input_names.size(), 0UL, "Inputs(%s) length is 0", name); + for (auto& input : input_names) { + if (!block_.HasVar(input)) return false; } return true; } bool HasOutputs(const std::string& name) const { - auto outputs = op_.Outputs(name); - if (outputs.size() == 0UL) { - return false; - } - for (auto& output : outputs) { - if (scope_.FindVar(output) == nullptr) { - return false; - } + const std::vector& output_names = op_.Output(name); + PADDLE_ENFORCE_GT(output_names.size(), 0UL, "Inputs(%s) length is 0", name); + for (auto& output : output_names) { + if (!block_.HasVar(name)) return false; } return true; } DDim GetInputDim(const std::string& name) const { - return GetDim(op_.Input(name)); + std::vector ddims = GetInputsDim(name); + PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Inputs(%s) length is not 1", name); + return ddims[0]; } void SetInputDim(const std::string& name, const DDim& dim) { - SetDim(op_.Input(name), dim); + SetInputsDim(name, {dim}); } DDim GetOutputDim(const std::string& name) const { - return GetDim(op_.Output(name)); + std::vector ddims = GetOutputsDim(name); + PADDLE_ENFORCE_EQ(ddims.size(), 1UL, "Outputs(%s) length is not 1", name); + return ddims[0]; } void SetOutputDim(const std::string& name, const DDim& dim) { - SetDim(op_.Output(name), dim); + SetOutputsDim(name, {dim}); } - AttrReader Attrs() const { return AttrReader(op_.Attrs()); } + AttrReader Attrs() const { return AttrReader(op_.GetAttrMap()); } const std::vector& Inputs(const std::string& name) const { - return op_.Inputs(name); + return op_.Input(name); } const std::vector& Outputs(const std::string& name) const { - return op_.Outputs(name); + return op_.Output(name); } private: - template - Tensor* GetTensor(const std::string& name) const { - Tensor* t = nullptr; - auto* var = scope_.FindVar(name); - if (!var->IsType() && !var->IsType()) { - if (Allocate) { - t = var->GetMutable(); - } else { - PADDLE_THROW("Variable(%s) should be tensor", name); - } - } else { - t = GetTensorFromVar(scope_.FindVar(name)); - } - return t; - } - DDim GetDim(const std::string& name) const { - return GetTensor(name)->dims(); + return framework::make_ddim(block_.Var(name)->Shape()); } void SetDim(const std::string& name, const DDim& dim) { - GetTensor(name)->Resize(dim); + block_.Var(name)->SetShape(framework::vectorize(dim)); } - const OperatorBase& op_; - const Scope& scope_; + const OpDescBind& op_; + const BlockDescBind& block_; }; class RuntimeInferShapeContext : public InferShapeContextBase { -- GitLab