From 7d5bdbbfee3e0c5a4dba3057c6dfb028dbf28ab8 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Sun, 3 Sep 2017 12:28:18 -0700 Subject: [PATCH] Add GenerateTemporaryNames/CheckAllInputOutputSet --- paddle/framework/op_info.h | 14 ++++++++++-- paddle/framework/operator.cc | 41 +++++++++++++++++++++++++++--------- paddle/framework/operator.h | 4 ++++ 3 files changed, 47 insertions(+), 12 deletions(-) diff --git a/paddle/framework/op_info.h b/paddle/framework/op_info.h index 94245c6c44a..b98d8f23a14 100644 --- a/paddle/framework/op_info.h +++ b/paddle/framework/op_info.h @@ -80,9 +80,19 @@ class OpInfoMap { } const OpInfo& Get(const std::string& type) const { + auto op_info_ptr = GetNullable(type); + PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered", + type); + return *op_info_ptr; + } + + const OpInfo* GetNullable(const std::string& type) const { auto it = map_.find(type); - PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); - return it->second; + if (it == map_.end()) { + return nullptr; + } else { + return &it->second; + } } template diff --git a/paddle/framework/operator.cc b/paddle/framework/operator.cc index 68fb469b561..88a872deaad 100644 --- a/paddle/framework/operator.cc +++ b/paddle/framework/operator.cc @@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type, const VariableNameMap& outputs, const AttributeMap& attrs) : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { - static std::atomic gUniqId(0UL); - for (auto& output : outputs_) { - for (auto& output_name : output.second) { - if (output_name == kTempVarName) { - output_name += type_; - output_name += "@"; - output_name += std::to_string(gUniqId.fetch_add(1)); - } - } - } + GenerateTemporaryNames(); + CheckAllInputOutputSet(); } std::vector OperatorBase::OutputVars(bool has_intermediate) const { @@ -156,6 +148,35 @@ std::vector OperatorBase::OutputVars(bool has_intermediate) const { return ret_val; } +void OperatorBase::CheckAllInputOutputSet() const { + auto& info_map = OpInfoMap::Instance(); + auto* op_info = info_map.GetNullable(Type()); + if (op_info == nullptr) return; + + for (auto& in : op_info->Proto().inputs()) { + PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(), + "input %s is not set", in.name()); + } + + for (auto& out : op_info->Proto().outputs()) { + PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(), + "output %s is not set", out.name()); + } +} + +void OperatorBase::GenerateTemporaryNames() { + static std::atomic gUniqId(0UL); + for (auto& output : outputs_) { + for (auto& output_name : output.second) { + if (output_name == kTempVarName) { + output_name += type_; + output_name += "@"; + output_name += std::to_string(gUniqId.fetch_add(1)); + } + } + } +} + void OpProtoAndCheckerMaker::Validate() { validated_ = true; CheckNoDuplicatedInOutAttrs(); diff --git a/paddle/framework/operator.h b/paddle/framework/operator.h index 60ca8b279e7..590e335fdc8 100644 --- a/paddle/framework/operator.h +++ b/paddle/framework/operator.h @@ -127,6 +127,10 @@ class OperatorBase { // IG (Inputs Gradients) VariableNameMap outputs_; AttributeMap attrs_; + + private: + void GenerateTemporaryNames(); + void CheckAllInputOutputSet() const; }; // Macro for define a clone method. -- GitLab