提交 7d5bdbbf 编写于 作者: Y Yu Yang

Add GenerateTemporaryNames/CheckAllInputOutputSet

上级 d7a1e40e
...@@ -80,9 +80,19 @@ class OpInfoMap { ...@@ -80,9 +80,19 @@ class OpInfoMap {
} }
const OpInfo& Get(const std::string& type) const { const OpInfo& Get(const std::string& type) const {
auto op_info_ptr = GetNullable(type);
PADDLE_ENFORCE_NOT_NULL(op_info_ptr, "Operator %s has not been registered",
type);
return *op_info_ptr;
}
const OpInfo* GetNullable(const std::string& type) const {
auto it = map_.find(type); auto it = map_.find(type);
PADDLE_ENFORCE(it != map_.end(), "Operator %s are not found", type); if (it == map_.end()) {
return it->second; return nullptr;
} else {
return &it->second;
}
} }
template <typename Callback> template <typename Callback>
......
...@@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type, ...@@ -119,16 +119,8 @@ OperatorBase::OperatorBase(const std::string& type,
const VariableNameMap& outputs, const VariableNameMap& outputs,
const AttributeMap& attrs) const AttributeMap& attrs)
: type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) { : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
static std::atomic<size_t> gUniqId(0UL); GenerateTemporaryNames();
for (auto& output : outputs_) { CheckAllInputOutputSet();
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
} }
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
...@@ -156,6 +148,35 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const { ...@@ -156,6 +148,35 @@ std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
return ret_val; return ret_val;
} }
void OperatorBase::CheckAllInputOutputSet() const {
auto& info_map = OpInfoMap::Instance();
auto* op_info = info_map.GetNullable(Type());
if (op_info == nullptr) return;
for (auto& in : op_info->Proto().inputs()) {
PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
"input %s is not set", in.name());
}
for (auto& out : op_info->Proto().outputs()) {
PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
"output %s is not set", out.name());
}
}
void OperatorBase::GenerateTemporaryNames() {
static std::atomic<size_t> gUniqId(0UL);
for (auto& output : outputs_) {
for (auto& output_name : output.second) {
if (output_name == kTempVarName) {
output_name += type_;
output_name += "@";
output_name += std::to_string(gUniqId.fetch_add(1));
}
}
}
}
void OpProtoAndCheckerMaker::Validate() { void OpProtoAndCheckerMaker::Validate() {
validated_ = true; validated_ = true;
CheckNoDuplicatedInOutAttrs(); CheckNoDuplicatedInOutAttrs();
......
...@@ -127,6 +127,10 @@ class OperatorBase { ...@@ -127,6 +127,10 @@ class OperatorBase {
// IG (Inputs Gradients) // IG (Inputs Gradients)
VariableNameMap outputs_; VariableNameMap outputs_;
AttributeMap attrs_; AttributeMap attrs_;
private:
void GenerateTemporaryNames();
void CheckAllInputOutputSet() const;
}; };
// Macro for define a clone method. // Macro for define a clone method.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册