From dc8e0b494def7346248d0d1c02f64c7c0d1ed0d7 Mon Sep 17 00:00:00 2001 From: fengjiayi Date: Thu, 7 Jun 2018 19:45:52 +0800 Subject: [PATCH] fix bugs in the implementation of 'HasInput' and 'HasOutput' --- paddle/fluid/framework/operator.cc | 32 ++++++++++++++++++++++++++++++ paddle/fluid/framework/operator.h | 4 ++-- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index f87d552149..1aec2642e3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -293,6 +293,38 @@ static Tensor* GetMutableTensorFromVar(Variable* var) { } } +bool ExecutionContext::HasInput(const std::string& name) const { + if (!op_.HasInputs(name)) { + return false; + } + auto& ins = Inputs(name); + size_t length = ins.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, + "Input %s should not have more than one inputs", name); + auto arg = ins[0]; + auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg); + return var != nullptr; +} + +bool ExecutionContext::HasOutput(const std::string& name) const { + if (!op_.HasOutputs(name)) { + return false; + } + auto& outs = Outputs(name); + size_t length = outs.size(); + if (length == 0) { + return false; + } + PADDLE_ENFORCE_EQ(length, 1UL, + "Output %s should not have more than one inputs", name); + auto arg = outs[0]; + auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg); + return var != nullptr; +} + template <> const Tensor* ExecutionContext::Input(const std::string& name) const { auto* var = InputVar(name); diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index 2f480e00c1..b1d75d0d0f 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -191,9 +191,9 @@ class ExecutionContext { return op_.Attr(name); } - bool HasInput(const std::string& name) const { return op_.HasInputs(name); } + bool HasInput(const std::string& name) const; - bool HasOutput(const std::string& name) const { return op_.HasOutputs(name); } + bool HasOutput(const std::string& name) const; size_t InputSize(const std::string& name) const { return op_.Inputs(name).size(); -- GitLab